83 changes: 73 additions & 10 deletions src/diffusers/models/adapter.py
@@ -231,7 +231,11 @@ class T2IAdapter(ModelMixin, ConfigMixin):
The number of channels of each downsample block's output hidden state. The `len(block_out_channels)` will
also determine the number of downsample blocks in the Adapter.
num_res_blocks (`int`, *optional*, defaults to 2):
Number of ResNet blocks in each downsample block
Number of ResNet blocks in each downsample block.
downscale_factor (`int`, *optional*, defaults to 8):
A factor that determines the total downscale factor of the Adapter.
adapter_type (`str`, *optional*, defaults to `full_adapter`):
The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
"""

@register_to_config
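As a quick orientation for the parameters documented above, a minimal usage sketch (the keyword names follow the docstring and should be checked against the actual constructor; the input size is hypothetical):

import torch
from diffusers import T2IAdapter

# Hypothetical configuration built from the documented parameters.
adapter = T2IAdapter(
    in_channels=3,
    num_res_blocks=2,               # ResNet blocks per downsample block
    downscale_factor=8,             # overall downscale factor of the Adapter
    adapter_type="full_adapter",
)

# A conditioning image (e.g. a sketch or depth map), batch of one.
image = torch.randn(1, 3, 512, 512)
features = adapter(image)           # a list of feature maps, one per downsample block
print([f.shape for f in features])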
@@ -275,6 +279,10 @@ def total_downscale_factor(self):


class FullAdapter(nn.Module):
r"""
See [`T2IAdapter`] for more information.
"""

def __init__(
self,
in_channels: int = 3,
Expand Down Expand Up @@ -321,6 +329,10 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:


class FullAdapterXL(nn.Module):
r"""
See [`T2IAdapter`] for more information.
"""

def __init__(
self,
in_channels: int = 3,
Expand Down Expand Up @@ -367,7 +379,22 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:


class AdapterBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
r"""
An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
`FullAdapterXL` models.

Parameters:
in_channels (`int`):
Number of channels of AdapterBlock's input.
out_channels (`int`):
Number of channels of AdapterBlock's output.
num_res_blocks (`int`):
Number of ResNet blocks in the AdapterBlock.
down (`bool`, *optional*, defaults to `False`):
Whether to perform downsampling on AdapterBlock's input.
"""

def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
super().__init__()

self.downsample = None
@@ -382,7 +409,7 @@ def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
*[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This method takes tensor x as input and applies downsampling and an input convolution if the
self.downsample and self.in_conv properties of the AdapterBlock are set. Then it applies a series of
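A rough shape sketch of the block described above (channel sizes are made up for the example; with down=True the spatial resolution is expected to halve):

import torch
from diffusers.models.adapter import AdapterBlock

# in_channels != out_channels triggers the 1x1 input convolution,
# down=True enables the downsampling step.
block = AdapterBlock(in_channels=64, out_channels=128, num_res_blocks=2, down=True)
x = torch.randn(1, 64, 32, 32)
out = block(x)
print(out.shape)  # expected: torch.Size([1, 128, 16, 16])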
@@ -400,13 +427,21 @@ def forward(self, x):


class AdapterResnetBlock(nn.Module):
def __init__(self, channels):
r"""
An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.

Parameters:
channels (`int`):
Number of channels of AdapterResnetBlock's input and output.
"""

def __init__(self, channels: int):
super().__init__()
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.act = nn.ReLU()
self.block2 = nn.Conv2d(channels, channels, kernel_size=1)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional
layer to it. The result is then added to the input tensor (residual connection) and returned.
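In other words, the block is a plain residual update. A standalone sketch of the same computation (not the library class itself):

import torch
import torch.nn as nn

class TinyResnetBlock(nn.Module):
    # conv3x3 -> ReLU -> conv1x1, then a residual add that preserves the input shape
    def __init__(self, channels: int):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.block2(self.act(self.block1(x)))
        return h + x

x = torch.randn(1, 32, 16, 16)
print(TinyResnetBlock(32)(x).shape)  # torch.Size([1, 32, 16, 16])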
@@ -423,6 +458,10 @@ def forward(self, x):


class LightAdapter(nn.Module):
r"""
See [`T2IAdapter`] for more information.
"""

def __init__(
self,
in_channels: int = 3,
@@ -449,7 +488,7 @@ def __init__(

self.total_downscale_factor = downscale_factor * (2 ** len(channels))

def forward(self, x):
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
This method takes the input tensor x, downscales it, and appends the result to a list of feature tensors. Each
feature tensor corresponds to a different level of processing within the LightAdapter.
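The total_downscale_factor attribute set in __init__ above is a simple product; a worked example with hypothetical values:

# total_downscale_factor = downscale_factor * (2 ** len(channels))
downscale_factor = 8
channels = [320, 640, 1280]                      # three downsampling levels (hypothetical)
total = downscale_factor * (2 ** len(channels))  # 8 * 2**3
print(total)                                     # 64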
@@ -466,7 +505,22 @@ def forward(self, x):


class LightAdapterBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
r"""
A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
`LightAdapter` model.

Parameters:
in_channels (`int`):
Number of channels of LightAdapterBlock's input.
out_channels (`int`):
Number of channels of LightAdapterBlock's output.
num_res_blocks (`int`):
Number of LightAdapterResnetBlocks in the LightAdapterBlock.
down (`bool`, *optional*, defaults to `False`):
Whether to perform downsampling on LightAdapterBlock's input.
"""

def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
super().__init__()
mid_channels = out_channels // 4
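A small shape sketch for the block described above (sizes are hypothetical; the internal width is out_channels // 4 as set in __init__):

import torch
from diffusers.models.adapter import LightAdapterBlock

# out_channels=128 gives mid_channels=32 inside the block.
block = LightAdapterBlock(in_channels=64, out_channels=128, num_res_blocks=1, down=True)
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)  # expected: torch.Size([1, 128, 16, 16]) with down=True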

@@ -478,7 +532,7 @@ def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This method takes tensor x as input and performs downsampling if required. Then it applies the input
convolution layer, a sequence of residual blocks, and the output convolution layer.
@@ -494,13 +548,22 @@ def forward(self, x):


class LightAdapterResnetBlock(nn.Module):
def __init__(self, channels):
"""
A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
architecture than `AdapterResnetBlock`.

Parameters:
channels (`int`):
Number of channels of LightAdapterResnetBlock's input and output.
"""

def __init__(self, channels: int):
super().__init__()
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.act = nn.ReLU()
self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This function takes input tensor x, processes it through one convolutional layer, ReLU activation, and
another convolutional layer, and adds the result to the input tensor.
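The main structural difference from `AdapterResnetBlock` is that the second convolution is 3x3 instead of 1x1; a standalone sketch of the equivalent computation:

import torch
import torch.nn as nn

channels = 32
block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
act = nn.ReLU()
block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)  # 3x3 here, 1x1 in AdapterResnetBlock

x = torch.randn(1, channels, 16, 16)
out = block2(act(block1(x))) + x  # residual add, shape preserved
print(out.shape)                  # torch.Size([1, 32, 16, 16])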
95 changes: 72 additions & 23 deletions src/diffusers/models/attention.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
@@ -26,7 +26,17 @@

@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
r"""
A gated self-attention dense layer that combines visual features and object features.

Parameters:
query_dim (`int`): The number of channels in the query.
context_dim (`int`): The number of channels in the context.
n_heads (`int`): The number of heads to use for attention.
d_head (`int`): The number of channels in each head.
"""

def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
super().__init__()

# we need a linear projection since we concatenate the visual feature and the object feature
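A minimal shape sketch of how this layer might be called (all dimensions are hypothetical; the output has the same shape as `x`):

import torch
from diffusers.models.attention import GatedSelfAttentionDense

layer = GatedSelfAttentionDense(query_dim=320, context_dim=768, n_heads=8, d_head=40)

x = torch.randn(2, 64, 320)      # visual tokens: (batch, n_visual, query_dim)
objs = torch.randn(2, 30, 768)   # object/grounding tokens: (batch, n_objs, context_dim)

out = layer(x, objs)
print(out.shape)                 # expected: torch.Size([2, 64, 320])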
@@ -43,7 +53,7 @@ def __init__(self, query_dim, context_dim, n_heads, d_head):

self.enabled = True

def forward(self, x, objs):
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
if not self.enabled:
return x

@@ -67,15 +77,25 @@ class BasicTransformerBlock(nn.Module):
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (`int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (`bool`, *optional*, defaults to `False`):
Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
"""

def __init__(
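For orientation, a minimal instantiation sketch using the core documented parameters (all sizes are hypothetical):

import torch
from diffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=320,                      # channels of the input/output hidden states
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,      # size of encoder_hidden_states for cross-attention
)

hidden_states = torch.randn(2, 64, 320)
encoder_hidden_states = torch.randn(2, 77, 768)
out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)                  # expected: torch.Size([2, 64, 320])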
@@ -175,7 +195,7 @@ def forward(
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
):
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
if self.use_ada_layer_norm:
@@ -301,7 +321,7 @@ def __init__(
if final_dropout:
self.net.append(nn.Dropout(dropout))

def forward(self, hidden_states, scale: float = 1.0):
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
for module in self.net:
if isinstance(module, (LoRACompatibleLinear, GEGLU)):
hidden_states = module(hidden_states, scale)
@@ -313,14 +333,19 @@ def forward(self, hidden_states, scale: float = 1.0):
class GELU(nn.Module):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.

Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
"""

def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
self.approximate = approximate

def gelu(self, gate):
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate, approximate=self.approximate)
# mps: gelu is not implemented for float16
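For reference, the tanh approximation selected by approximate="tanh" can be written out and checked against F.gelu directly (a sketch; assumes a PyTorch version where F.gelu accepts the approximate argument):

import math

import torch
import torch.nn.functional as F

x = torch.randn(4)

# 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
manual = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))
builtin = F.gelu(x, approximate="tanh")
print(torch.allclose(manual, builtin, atol=1e-5))  # should print True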
@@ -345,7 +370,7 @@ def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)

def gelu(self, gate):
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
@@ -357,46 +382,57 @@ def forward(self, hidden_states, scale: float = 1.0):


class ApproximateGELU(nn.Module):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
r"""
The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
https://arxiv.org/abs/1606.08415.

For more details, see section 2: https://arxiv.org/abs/1606.08415
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Module):
"""
r"""
Norm layer modified to incorporate timestep embeddings.

Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the dictionary of embeddings.
"""

def __init__(self, embedding_dim, num_embeddings):
def __init__(self, embedding_dim: int, num_embeddings: int):
super().__init__()
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

def forward(self, x, timestep):
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(self.emb(timestep)))
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x
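A small usage sketch for the timestep-conditioned normalization above (shapes and sizes are hypothetical; a scalar timestep index is used because the embedding output is chunked along dim 0):

import torch
from diffusers.models.attention import AdaLayerNorm

norm = AdaLayerNorm(embedding_dim=320, num_embeddings=1000)  # e.g. 1000 diffusion steps

x = torch.randn(2, 64, 320)
timestep = torch.tensor(10)  # scalar diffusion step index
out = norm(x, timestep)
print(out.shape)             # torch.Size([2, 64, 320])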


class AdaLayerNormZero(nn.Module):
"""
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).

Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the dictionary of embeddings.
"""

def __init__(self, embedding_dim, num_embeddings):
def __init__(self, embedding_dim: int, num_embeddings: int):
super().__init__()

self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
@@ -405,16 +441,29 @@ def __init__(self, embedding_dim, num_embeddings):
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

def forward(self, x, timestep, class_labels, hidden_dtype=None):
def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
class_labels: torch.LongTensor,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
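The five returned tensors are the modulated hidden states plus the gates and shift/scale pairs that the surrounding transformer block applies to its attention and feed-forward branches. A small shape sketch (hypothetical sizes):

import torch
from diffusers.models.attention import AdaLayerNormZero

norm = AdaLayerNormZero(embedding_dim=320, num_embeddings=10)  # 10 class labels (hypothetical)

hidden_states = torch.randn(2, 64, 320)
timestep = torch.randint(0, 1000, (2,))
class_labels = torch.randint(0, 10, (2,))

x, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(
    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
print(x.shape, gate_msa.shape)  # expected: torch.Size([2, 64, 320]) torch.Size([2, 320])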


class AdaGroupNorm(nn.Module):
"""
r"""
GroupNorm layer modified to incorporate timestep embeddings.

Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the dictionary of embeddings.
num_groups (`int`): The number of groups to separate the channels into.
act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
"""

def __init__(
@@ -431,7 +480,7 @@ def __init__(

self.linear = nn.Linear(embedding_dim, out_dim * 2)

def forward(self, x, emb):
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
if self.act:
emb = self.act(emb)
emb = self.linear(emb)