diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 390b752abe15..284f31702dd5 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -127,10 +127,7 @@ def get_3d_sincos_pos_embed( # 1. Spatial grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale - grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first - grid = torch.stack(grid, dim=0) - - grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) + grid = torch.meshgrid(grid_h, grid_w, indexing="ij") pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt") # 2. Temporal @@ -195,10 +192,7 @@ def _get_3d_sincos_pos_embed_np( # 1. Spatial grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) + grid = np.meshgrid(grid_h, grid_w, indexing="ij") pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) # 2. 
Temporal @@ -274,23 +268,25 @@ def get_2d_sincos_pos_embed( / (grid_size[1] / base_size) / interpolation_scale ) - grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first - grid = torch.stack(grid, dim=0) - - grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + grid = torch.meshgrid(grid_h, grid_w, indexing="ij") pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type) if cls_token and extra_tokens > 0: pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0) return pos_embed + def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"): r""" This function generates 2D sinusoidal positional embeddings from a grid. + Note: this function has been fixed to accept grid tuples with the height grid first and the width grid second, + in order to match the names of the `emb_h` and `emb_w` variables. + However, this fix breaks consistency with the original (incorrect) implementation found in models like MAE, DiT, etc. + Take care to make sure you understand what the indices actually mean. Args: embed_dim (`int`): The embedding dimension. - grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`. + grid (`Tuple[torch.Tensor, torch.Tensor]`): Two grids of positions each with shape `(H, W)`. 
Returns: `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` @@ -313,7 +309,9 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"): emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2) - emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D) + # concatenate two embeddings as width first + # to match pre-trained model behavior + emb = torch.concat([emb_w, emb_h], dim=1) # (H*W, D) return emb @@ -381,10 +379,7 @@ def get_2d_sincos_pos_embed_np( grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + grid = np.meshgrid(grid_h, grid_w, indexing="ij") pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid) if cls_token and extra_tokens > 0: pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) @@ -397,7 +392,7 @@ def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid): Args: embed_dim (`int`): The embedding dimension. - grid (`np.ndarray`): Grid of positions with shape `(H * W,)`. + grid (`Tuple[np.ndarray, np.ndarray]`): Two grids of positions each with shape `(H, W)`. 
Returns: `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` @@ -409,7 +404,9 @@ def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid): emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2) - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + # concatenate two embeddings as width first + # to match pre-trained model behavior + emb = np.concatenate([emb_w, emb_h], axis=1) # (H*W, D) return emb