37 changes: 17 additions & 20 deletions src/diffusers/models/embeddings.py
@@ -127,10 +127,7 @@ def get_3d_sincos_pos_embed(
# 1. Spatial
grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
grid = torch.stack(grid, dim=0)

grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
grid = torch.meshgrid(grid_h, grid_w, indexing="ij")
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")

# 2. Temporal
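For context (not part of the diff): a minimal sketch of why this replacement is behavior-preserving. The old width-first `"xy"` meshgrid produced the same two `(H, W)` grids as the new height-first `"ij"` meshgrid, just in swapped order, so only the indexing and naming change. The sizes below are arbitrary illustration values.

```python
import torch

H, W = 4, 6  # arbitrary spatial size for illustration
grid_h = torch.arange(H, dtype=torch.float32)
grid_w = torch.arange(W, dtype=torch.float32)

# New construction: height grid first, "ij" indexing -> two (H, W) grids
grid_new = torch.meshgrid(grid_h, grid_w, indexing="ij")

# Old construction: width first with "xy" indexing, then stacked and reshaped
grid_old = torch.stack(torch.meshgrid(grid_w, grid_h, indexing="xy"), dim=0)
grid_old = grid_old.reshape([2, 1, H, W])

# The same grids come out either way; only their order (and thus the naming) differs:
# grid_old[0] was the *width* grid and grid_old[1] the *height* grid.
assert torch.equal(grid_new[0], grid_old[1, 0])  # height grid
assert torch.equal(grid_new[1], grid_old[0, 0])  # width grid
```

The NumPy hunk below makes the same change with `np.meshgrid`.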
@@ -195,10 +192,7 @@ def _get_3d_sincos_pos_embed_np(
# 1. Spatial
grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)

grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
grid = np.meshgrid(grid_h, grid_w, indexing="ij")
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)

# 2. Temporal
@@ -274,23 +268,25 @@ def get_2d_sincos_pos_embed(
/ (grid_size[1] / base_size)
/ interpolation_scale
)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
grid = torch.stack(grid, dim=0)

grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
grid = torch.meshgrid(grid_h, grid_w, indexing="ij")
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
if cls_token and extra_tokens > 0:
pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
return pos_embed



def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
r"""
This function generates 2D sinusoidal positional embeddings from a grid.
Note: this function has been fixed to accept grid tuples with the height grid first and the width grid second,
so that the indices match the `emb_h` and `emb_w` variable names. This breaks consistency with the original
(incorrect) grid ordering used in implementations such as MAE and DiT, so take care to understand what each
index actually means.

Args:
embed_dim (`int`): The embedding dimension.
grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
grid (`Tuple[torch.Tensor, torch.Tensor]`): Two grids of positions each with shape `(H, W)`.

Returns:
`torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
@@ -313,7 +309,9 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2)

emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
# concatenate the two halves width first
# to match the output layout of pre-trained models
emb = torch.concat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
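As a sanity check (my own sketch, not part of the diff), the renamed halves concatenated width-first reproduce exactly what the old width-first grid produced, so checkpoints trained against the previous layout are unaffected. `H`, `W`, and `embed_dim` below are arbitrary, and the snippet assumes this branch of `diffusers` is installed.

```python
import torch
from diffusers.models.embeddings import (
    get_1d_sincos_pos_embed_from_grid,
    get_2d_sincos_pos_embed_from_grid,
)

H, W, embed_dim = 4, 6, 8  # arbitrary values for illustration
grid = torch.meshgrid(
    torch.arange(H, dtype=torch.float32),
    torch.arange(W, dtype=torch.float32),
    indexing="ij",
)  # height grid first, as the updated docstring requires

new_emb = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="pt")

# The old code fed the width grid into `emb_h` and the height grid into `emb_w`,
# then concatenated [emb_h, emb_w]; numerically that is width half first, height half second.
old_style = torch.concat(
    [
        get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type="pt"),  # width half
        get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type="pt"),  # height half
    ],
    dim=1,
)
assert torch.equal(new_emb, old_style)  # output layout is unchanged
```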


@@ -381,10 +379,7 @@ def get_2d_sincos_pos_embed_np(

grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)

grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
grid = np.meshgrid(grid_h, grid_w, indexing="ij")
pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
@@ -397,7 +392,7 @@ def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):

Args:
embed_dim (`int`): The embedding dimension.
grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
grid (`Tuple[np.ndarray, np.ndarray]`): Two grids of positions each with shape `(H, W)`.

Returns:
`np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
@@ -409,7 +404,9 @@ def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2)

emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
# concatenate the two halves width first
# to match the output layout of pre-trained models
emb = np.concatenate([emb_w, emb_h], axis=1) # (H*W, D)
return emb
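One more hedged check (again my own sketch, assuming this branch is installed): with the same height-first grids, the torch and NumPy code paths should now produce matching embeddings.

```python
import numpy as np
import torch
from diffusers.models.embeddings import (
    get_2d_sincos_pos_embed_from_grid,
    get_2d_sincos_pos_embed_from_grid_np,
)

H, W, embed_dim = 4, 6, 8  # arbitrary values for illustration
grid_pt = torch.meshgrid(
    torch.arange(H, dtype=torch.float32),
    torch.arange(W, dtype=torch.float32),
    indexing="ij",
)
grid_np = np.meshgrid(
    np.arange(H, dtype=np.float32),
    np.arange(W, dtype=np.float32),
    indexing="ij",
)

emb_pt = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_pt, output_type="pt")
emb_np = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid_np)
assert np.allclose(emb_pt.numpy(), emb_np, atol=1e-6)  # both backends agree
```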

