From 1a0cd1ce5ae55c914c7230e044647b2b1f7647e9 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 9 Dec 2024 13:57:38 +0000 Subject: [PATCH 1/5] Use torch in get_2d_sincos_pos_embed --- src/diffusers/models/embeddings.py | 55 ++++++++++++------- .../pipelines/unidiffuser/modeling_uvit.py | 2 +- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 8f8f1073da74..0153e794cb80 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -139,7 +139,13 @@ def get_3d_sincos_pos_embed( def get_2d_sincos_pos_embed( - embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 + embed_dim, + grid_size, + cls_token=False, + extra_tokens=0, + interpolation_scale=1.0, + base_size=16, + device: Optional[torch.device] = None, ): """ Creates 2D sinusoidal positional embeddings. @@ -157,22 +163,30 @@ def get_2d_sincos_pos_embed( The scale of the interpolation. Returns: - pos_embed (`np.ndarray`): + pos_embed (`torch.Tensor`): Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size, embed_dim]` if using cls_token """ if isinstance(grid_size, int): grid_size = (grid_size, grid_size) - grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale - grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) + grid_h = ( + torch.arange(grid_size[0], device=device, dtype=torch.float32) + / (grid_size[0] / base_size) + / interpolation_scale + ) + grid_w = ( + torch.arange(grid_size[1], device=device, dtype=torch.float32) + / (grid_size[1] / base_size) + / interpolation_scale + ) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first + grid = torch.stack(grid, dim=0) grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token and extra_tokens > 0: - pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0) return pos_embed @@ -182,10 +196,10 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): Args: embed_dim (`int`): The embedding dimension. - grid (`np.ndarray`): Grid of positions with shape `(H * W,)`. + grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`. Returns: - `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` + `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` """ if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") @@ -194,7 +208,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D) return emb @@ -204,25 +218,25 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): Args: embed_dim (`int`): The embedding dimension `D` - pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)` + pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)` Returns: - `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`. 
+ `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`. """ if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") - omega = np.arange(embed_dim // 2, dtype=np.float64) + omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + out = torch.outer(pos, omega) # (M, D/2), outer product - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D) return emb @@ -291,7 +305,7 @@ def __init__( embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale ) persistent = True if pos_embed_max_size else False - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) + self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent) else: raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") @@ -341,8 +355,9 @@ def forward(self, latent): grid_size=(height, width), base_size=self.base_size, interpolation_scale=self.interpolation_scale, + device=latent.device, ) - pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) + pos_embed = pos_embed.float().unsqueeze(0) else: pos_embed = self.pos_embed @@ -554,7 +569,7 @@ def __init__( pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size) pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size) - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False) + self.register_buffer("pos_embed", pos_embed.float(), persistent=False) def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: batch_size, channel, height, width = hidden_states.shape diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py index cb1514b153ce..23aba3ecb123 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -105,7 +105,7 @@ def __init__( self.use_pos_embed = use_pos_embed if self.use_pos_embed: pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=False) def forward(self, latent): latent = self.proj(latent) From 6f0b4d6fe87840711aa279b10e48daaf5e033f4d Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 9 Dec 2024 14:37:24 +0000 Subject: [PATCH 2/5] Use torch in get_3d_sincos_pos_embed --- src/diffusers/models/embeddings.py | 40 ++++++++++++++++++------------ 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0153e794cb80..38cada6d787f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -84,7 +84,8 @@ def get_3d_sincos_pos_embed( temporal_size: int, spatial_interpolation_scale: float = 1.0, temporal_interpolation_scale: float = 1.0, -) -> np.ndarray: + device: Optional[torch.device] = None, +) -> torch.Tensor: r""" Creates 3D sinusoidal 
positional embeddings. @@ -102,7 +103,7 @@ def get_3d_sincos_pos_embed( Scale factor for temporal grid interpolation. Returns: - `np.ndarray`: + `torch.Tensor`: The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1], embed_dim]`. """ @@ -115,26 +116,28 @@ def get_3d_sincos_pos_embed( embed_dim_temporal = embed_dim // 4 # 1. Spatial - grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale - grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) + grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale + grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first + grid = torch.stack(grid, dim=0) grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) # 2. Temporal - grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale + grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) # 3. Concat - pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] - pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3] + pos_embed_spatial = pos_embed_spatial[None, :, :] + pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3] - pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] - pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4] + pos_embed_temporal = pos_embed_temporal[:, None, :] + pos_embed_temporal = pos_embed_temporal.repeat_interleave( + spatial_size[0] * spatial_size[1], dim=1 + ) # [T, H*W, D // 4] - pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D] + pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1) # [T, H*W, D] return pos_embed @@ -468,7 +471,9 @@ def __init__( pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) - def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor: + def _get_positional_embeddings( + self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None + ) -> torch.Tensor: post_patch_height = sample_height // self.patch_size post_patch_width = sample_width // self.patch_size post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 @@ -480,8 +485,9 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp post_time_compression_frames, self.spatial_interpolation_scale, self.temporal_interpolation_scale, + device=device, ) - pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1) + pos_embedding = pos_embedding.flatten(0, 1) joint_pos_embedding = torch.zeros( 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False ) @@ -536,8 +542,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): or self.sample_width != width or self.sample_frames != 
pre_time_compression_frames ): - pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames) - pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype) + pos_embedding = self._get_positional_embeddings( + height, width, pre_time_compression_frames, device=embeds.device + ) + pos_embedding = pos_embedding.to(dtype=embeds.dtype) else: pos_embedding = self.pos_embedding From 1adf5f0a773f5becbc7a7ece932055f00a294b2d Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 9 Dec 2024 14:44:56 +0000 Subject: [PATCH 3/5] get_1d_sincos_pos_embed_from_grid in LatteTransformer3DModel --- src/diffusers/models/transformers/latte_transformer_3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 7e2b1273687d..44e4ef789a80 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -158,7 +158,7 @@ def __init__( temp_pos_embed = get_1d_sincos_pos_embed_from_grid( inner_dim, torch.arange(0, video_length).unsqueeze(1) ) # 1152 hidden size - self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False) self.gradient_checkpointing = False From c5bd7714f8dad72c3edb77d41fee40123ba79482 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 14:51:26 +0000 Subject: [PATCH 4/5] deprecate --- src/diffusers/models/embeddings.py | 224 +++++++++++++++++- .../transformers/latte_transformer_3d.py | 2 +- .../pipelines/unidiffuser/modeling_uvit.py | 2 +- 3 files changed, 217 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 38cada6d787f..ee7300606a13 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -85,6 +85,7 @@ def get_3d_sincos_pos_embed( spatial_interpolation_scale: float = 1.0, temporal_interpolation_scale: float = 1.0, device: Optional[torch.device] = None, + output_type: str = "np", ) -> torch.Tensor: r""" Creates 3D sinusoidal positional embeddings. @@ -107,6 +108,20 @@ def get_3d_sincos_pos_embed( The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1], embed_dim]`. """ + if output_type == "np": + deprecation_message = ( + "`get_3d_sincos_pos_embed` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return get_3d_sincos_pos_embed_np( + embed_dim=embed_dim, + spatial_size=spatial_size, + temporal_size=temporal_size, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + ) if embed_dim % 4 != 0: raise ValueError("`embed_dim` must be divisible by 4") if isinstance(spatial_size, int): @@ -122,11 +137,11 @@ def get_3d_sincos_pos_embed( grid = torch.stack(grid, dim=0) grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) - pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt") # 2. 
Temporal grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale - pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt") # 3. Concat pos_embed_spatial = pos_embed_spatial[None, :, :] @@ -141,6 +156,66 @@ def get_3d_sincos_pos_embed( return pos_embed +def get_3d_sincos_pos_embed_np( + embed_dim: int, + spatial_size: Union[int, Tuple[int, int]], + temporal_size: int, + spatial_interpolation_scale: float = 1.0, + temporal_interpolation_scale: float = 1.0, +) -> np.ndarray: + r""" + Creates 3D sinusoidal positional embeddings. + + Args: + embed_dim (`int`): + The embedding dimension of inputs. It must be divisible by 16. + spatial_size (`int` or `Tuple[int, int]`): + The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both + spatial dimensions (height and width). + temporal_size (`int`): + The temporal dimension of postional embeddings (number of frames). + spatial_interpolation_scale (`float`, defaults to 1.0): + Scale factor for spatial grid interpolation. + temporal_interpolation_scale (`float`, defaults to 1.0): + Scale factor for temporal grid interpolation. + + Returns: + `np.ndarray`: + The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1], + embed_dim]`. + """ + if embed_dim % 4 != 0: + raise ValueError("`embed_dim` must be divisible by 4") + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + embed_dim_spatial = 3 * embed_dim // 4 + embed_dim_temporal = embed_dim // 4 + + # 1. Spatial + grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale + grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) + + # 2. Temporal + grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) + + # 3. Concat + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3] + + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4] + + pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D] + return pos_embed + + def get_2d_sincos_pos_embed( embed_dim, grid_size, @@ -149,6 +224,7 @@ def get_2d_sincos_pos_embed( interpolation_scale=1.0, base_size=16, device: Optional[torch.device] = None, + output_type: str = "np", ): """ Creates 2D sinusoidal positional embeddings. @@ -170,6 +246,21 @@ def get_2d_sincos_pos_embed( Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size, embed_dim]` if using cls_token """ + if output_type == "np": + deprecation_message = ( + "`get_2d_sincos_pos_embed` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." 
+ ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return get_2d_sincos_pos_embed_np( + embed_dim=embed_dim, + grid_size=grid_size, + cls_token=cls_token, + extra_tokens=extra_tokens, + interpolation_scale=interpolation_scale, + base_size=base_size, + ) if isinstance(grid_size, int): grid_size = (grid_size, grid_size) @@ -187,13 +278,13 @@ def get_2d_sincos_pos_embed( grid = torch.stack(grid, dim=0) grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type) if cls_token and extra_tokens > 0: pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0) return pos_embed -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"): r""" This function generates 2D sinusoidal positional embeddings from a grid. @@ -204,18 +295,29 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): Returns: `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` """ + if output_type == "np": + deprecation_message = ( + "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return get_2d_sincos_pos_embed_from_grid_np( + embed_dim=embed_dim, + grid=grid, + ) if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2) emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D) return emb -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"): """ This function generates 1D positional embeddings from a grid. @@ -226,6 +328,14 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): Returns: `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`. """ + if output_type == "np": + deprecation_message = ( + "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos) if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") @@ -243,6 +353,94 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): return emb +def get_2d_sincos_pos_embed_np( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 +): + """ + Creates 2D sinusoidal positional embeddings. + + Args: + embed_dim (`int`): + The embedding dimension. + grid_size (`int`): + The size of the grid height and width. + cls_token (`bool`, defaults to `False`): + Whether or not to add a classification token. + extra_tokens (`int`, defaults to `0`): + The number of extra tokens to add. 
+ interpolation_scale (`float`, defaults to `1.0`): + The scale of the interpolation. + + Returns: + pos_embed (`np.ndarray`): + Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size, + embed_dim]` if using cls_token + """ + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid): + r""" + This function generates 2D sinusoidal positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension. + grid (`np.ndarray`): Grid of positions with shape `(H * W,)`. + + Returns: + `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos): + """ + This function generates 1D positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension `D` + pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)` + + Returns: + `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`. + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding with support for SD3 cropping. 
@@ -305,7 +503,11 @@ def __init__( self.pos_embed = None elif pos_embed_type == "sincos": pos_embed = get_2d_sincos_pos_embed( - embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale + embed_dim, + grid_size, + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + output_type="pt", ) persistent = True if pos_embed_max_size else False self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent) @@ -359,6 +561,7 @@ def forward(self, latent): base_size=self.base_size, interpolation_scale=self.interpolation_scale, device=latent.device, + output_type="pt", ) pos_embed = pos_embed.float().unsqueeze(0) else: @@ -486,6 +689,7 @@ def _get_positional_embeddings( self.spatial_interpolation_scale, self.temporal_interpolation_scale, device=device, + output_type="pt", ) pos_embedding = pos_embedding.flatten(0, 1) joint_pos_embedding = torch.zeros( @@ -575,7 +779,9 @@ def __init__( # Linear projection for text embeddings self.text_proj = nn.Linear(text_hidden_size, hidden_size) - pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size) + pos_embed = get_2d_sincos_pos_embed( + hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt" + ) pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size) self.register_buffer("pos_embed", pos_embed.float(), persistent=False) diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 44e4ef789a80..d34ccfd20108 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -156,7 +156,7 @@ def __init__( # define temporal positional embedding temp_pos_embed = get_1d_sincos_pos_embed_from_grid( - inner_dim, torch.arange(0, video_length).unsqueeze(1) + inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt" ) # 1152 hidden size self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False) diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py index 23aba3ecb123..1e285a9670e2 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -104,7 +104,7 @@ def __init__( self.use_pos_embed = use_pos_embed if self.use_pos_embed: - pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5), output_type="pt") self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=False) def forward(self, latent): From e6730ce704b30c2a13b2a93276e61f400c9c70ac Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 13 Dec 2024 09:25:15 +0000 Subject: [PATCH 5/5] move deprecate, make private --- src/diffusers/models/embeddings.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index fc35b0c87e9b..b423c17c1246 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -109,13 +109,7 @@ def get_3d_sincos_pos_embed( embed_dim]`. """ if output_type == "np": - deprecation_message = ( - "`get_3d_sincos_pos_embed` uses `torch` and supports `device`." - " `from_numpy` is no longer required." - " Pass `output_type='pt' to use the new version now." 
- ) - deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) - return get_3d_sincos_pos_embed_np( + return _get_3d_sincos_pos_embed_np( embed_dim=embed_dim, spatial_size=spatial_size, temporal_size=temporal_size, @@ -156,7 +150,7 @@ def get_3d_sincos_pos_embed( return pos_embed -def get_3d_sincos_pos_embed_np( +def _get_3d_sincos_pos_embed_np( embed_dim: int, spatial_size: Union[int, Tuple[int, int]], temporal_size: int, @@ -184,6 +178,12 @@ def get_3d_sincos_pos_embed_np( The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1], embed_dim]`. """ + deprecation_message = ( + "`get_3d_sincos_pos_embed` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) if embed_dim % 4 != 0: raise ValueError("`embed_dim` must be divisible by 4") if isinstance(spatial_size, int):