From 3f8a55c7137775f3c50fe2be8ab0140aa656540e Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 17:08:38 +0000 Subject: [PATCH 01/16] documenting `attention_flax.py` file --- src/diffusers/models/attention_flax.py | 81 ++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 525be4818dcc..59919b977856 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -17,6 +17,27 @@ class FlaxAttentionBlock(nn.Module): + r""" + A multi-head attention as described in: https://arxiv.org/abs/1706.03762 + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + query_dim (:obj:`int`): hidden states dimension + heads (:obj:`int`, *optional*, defaults to 8): Number of heads + dim_head (:obj:`int`, *optional*, defaults to 64): hidden states dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + + """ query_dim: int heads: int = 8 dim_head: int = 64 @@ -74,6 +95,27 @@ def __call__(self, hidden_states, context=None, deterministic=True): class FlaxBasicTransformerBlock(nn.Module): + r""" + A transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: + https://arxiv.org/abs/1706.03762 + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + dim (:obj:`int`): Inner hidden states dimension + n_heads (:obj:`int`): Number of heads + d_head (:obj:`int`): Hidden states dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ dim: int n_heads: int d_head: int @@ -110,6 +152,28 @@ def __call__(self, hidden_states, context, deterministic=True): class FlaxSpatialTransformer(nn.Module): + r""" + A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: + https://arxiv.org/pdf/1506.02025.pdf + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. 
Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`): Input number of channels + n_heads (:obj:`int`): Number of heads + d_head (:obj:`int`): Hidden states dimension inside each head + depth (:obj:`int`, *optional*, defaults to 1): Number of transformers block + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ in_channels: int n_heads: int d_head: int @@ -163,6 +227,15 @@ def __call__(self, hidden_states, context, deterministic=True): class FlaxGluFeedForward(nn.Module): + r""" + Flax module that encapsulates two Linear layers separated by a gated linear unit activation from: + https://arxiv.org/abs/2002.05202 + + Parameters: + dim (:obj:`int`): Inner hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ dim: int dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 @@ -180,6 +253,14 @@ def __call__(self, hidden_states, deterministic=True): class FlaxGEGLU(nn.Module): + r""" + Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from + https://arxiv.org/abs/2002.05202. + + arameters: + dim (:obj:`int`): Input hidden states dimension dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout + rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ dim: int dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 From f4ecb28e22f702822597ae2d67e130d0a78fdf87 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 17:24:28 +0000 Subject: [PATCH 02/16] documenting `embeddings_flax.py` --- src/diffusers/models/embeddings_flax.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 63442ab997b4..521347b6efc8 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -37,6 +37,13 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim): class FlaxTimestepEmbedding(nn.Module): + r""" + Time step Embedding Module. Learns embeddings for input time steps. 
+ + Args: + time_embed_dim (`int`, *optional*, defaults to `32`): time step embedding dimension + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ time_embed_dim: int = 32 dtype: jnp.dtype = jnp.float32 @@ -49,6 +56,12 @@ def __call__(self, temb): class FlaxTimesteps(nn.Module): + r""" + Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 + + Args: + dim (`int`, *optional*, defaults to `32`): time step embedding dimension + """ dim: int = 32 @nn.compact From a840591e58b99f05b0d680101760a8eb32ab0150 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 17:37:17 +0000 Subject: [PATCH 03/16] documenting `unet_blocks_flax.py` --- src/diffusers/models/unet_blocks_flax.py | 66 ++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py index fe6cc3b194e3..2c61c4fcd93e 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_blocks_flax.py @@ -19,6 +19,21 @@ class FlaxCrossAttnDownBlock2D(nn.Module): + r""" + Cross Attention 2D Downsizing block - original architecture from Unet transformers: + https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): Input channels + out_channels (:obj:`int`): Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers + attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ in_channels: int out_channels: int dropout: float = 0.0 @@ -73,6 +88,18 @@ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=Tru class FlaxDownBlock2D(nn.Module): + r""" + Flax 2D downsizing block + + Parameters: + in_channels (:obj:`int`): Input channels + out_channels (:obj:`int`): Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ in_channels: int out_channels: int dropout: float = 0.0 @@ -113,6 +140,21 @@ def __call__(self, hidden_states, temb, deterministic=True): class FlaxCrossAttnUpBlock2D(nn.Module): + r""" + Cross Attention 2D Upsampling block - original architecture from Unet transformers: + https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): Input channels + out_channels (:obj:`int`): Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers + attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_upsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add upsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ in_channels: int out_channels: int prev_output_channel: int @@ -170,6 +212,19 @@ def 
__call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_ class FlaxUpBlock2D(nn.Module): + r""" + Flax 2D upsampling block + + Parameters: + in_channels (:obj:`int`): Input channels + out_channels (:obj:`int`): Output channels + prev_output_channel (:obj:`int`): Output channels from the previous block + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ in_channels: int out_channels: int prev_output_channel: int @@ -214,6 +269,17 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=T class FlaxUNetMidBlock2DCrossAttn(nn.Module): + r""" + Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers + attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ in_channels: int dropout: float = 0.0 num_layers: int = 1 From ea8cbdb6ec839bd4ed1510fdefc32df545643c68 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Tue, 20 Sep 2022 18:35:09 +0000 Subject: [PATCH 04/16] Add new objs to doc page --- docs/source/api/models.mdx | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index 525548e7c302..c92fdccb8333 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -45,3 +45,21 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## AutoencoderKL [[autodoc]] AutoencoderKL + +## FlaxModelMixin +[[autodoc]] FlaxModelMixin + +## FlaxUNet2DConditionOutput +[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput + +## FlaxUNet2DConditionModel +[[autodoc]] FlaxUNet2DConditionModel + +## FlaxDecoderOutput +[[autodoc]] models.vae_flax.FlaxDecoderOutput + +## FlaxAutoencoderKLOutput +[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput + +## FlaxAutoencoderKL +[[autodoc]] FlaxAutoencoderKL From 1341223f178f5344a7d914904f741e0a7a99c682 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 20:38:54 +0000 Subject: [PATCH 05/16] document `vae_flax.py` --- src/diffusers/models/vae_flax.py | 202 ++++++++++++++++++++++++++++++- 1 file changed, 201 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index eba9259b8201..df64dd3d32ac 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -23,6 +23,7 @@ class FlaxDecoderOutput(BaseOutput): Args: sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): Decoded output sample of the model. Output of the last layer of the model. 
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` """ sample: jnp.ndarray @@ -43,6 +44,15 @@ class FlaxAutoencoderKLOutput(BaseOutput): class FlaxUpsample2D(nn.Module): + """ + Flax implementation of 2D Upsample layer + + Args: + in_channels (`int`): + Input channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ + in_channels: int dtype: jnp.dtype = jnp.float32 @@ -67,6 +77,15 @@ def __call__(self, hidden_states): class FlaxDownsample2D(nn.Module): + """ + Flax implementation of 2D Downsample layer + + Args: + in_channels (`int`): + Input channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + """ + in_channels: int dtype: jnp.dtype = jnp.float32 @@ -87,6 +106,22 @@ def __call__(self, hidden_states): class FlaxResnetBlock2D(nn.Module): + """ + Flax implementation of 2D Resnet Block. + + Args: + in_channels (`int`): + Input channels + out_channels (`int`): + Output channels + dropout_prob (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): + Whether to use `nin_shortcut`. This activates a new layer inside ResNet block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ + in_channels: int out_channels: int = None dropout_prob: float = 0.0 @@ -145,6 +180,18 @@ def __call__(self, hidden_states, deterministic=True): class FlaxAttentionBlock(nn.Module): + r""" + Flax Convolutional based multi-head attention block for diffusion-based VAE. + + Parameters: + channels (:obj:`int`): + Input channels + num_head_channels (:obj:`int`, *optional*, defaults to `None`): + Number of attention heads + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + + """ channels: int num_head_channels: int = None dtype: jnp.dtype = jnp.float32 @@ -202,6 +249,23 @@ def __call__(self, hidden_states): class FlaxDownEncoderBlock2D(nn.Module): + r""" + Flax Resnet blocks-based Encoder block for diffusion-based VAE. + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsample layer + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ in_channels: int out_channels: int dropout: float = 0.0 @@ -237,6 +301,23 @@ def __call__(self, hidden_states, deterministic=True): class FlaxUpEncoderBlock2D(nn.Module): + r""" + Flax Resnet blocks-based Encoder block for diffusion-based VAE. + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsample layer + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ in_channels: int out_channels: int dropout: float = 0.0 @@ -272,6 +353,21 @@ def __call__(self, hidden_states, deterministic=True): class FlaxUNetMidBlock2D(nn.Module): + r""" + Flax Unet Mid-Block module. 
+ + Parameters: + in_channels (:obj:`int`): + Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`): + Number of attention heads for each attention block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ in_channels: int dropout: float = 0.0 num_layers: int = 1 @@ -318,6 +414,39 @@ def __call__(self, hidden_states, deterministic=True): class FlaxEncoder(nn.Module): + r""" + Flax Implementation of VAE Encoder. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): + DownEncoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + norm_num_groups (:obj:`int`, *optional*, defaults to `2`): + norm num group + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + double_z (:obj:`bool`, *optional*, defaults to `False`): + Whether to double the last output channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ in_channels: int = 3 out_channels: int = 3 down_block_types: Tuple[str] = ("DownEncoderBlock2D",) @@ -393,7 +522,39 @@ def __call__(self, sample, deterministic: bool = True): class FlaxDecoder(nn.Module): - dtype: jnp.dtype = jnp.float32 + r""" + Flax Implementation of VAE Decoder. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. 
+ + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): + UpDecoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + norm_num_groups (:obj:`int`, *optional*, defaults to `2`): + norm num group + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + double_z (:obj:`bool`, *optional*, defaults to `False`): + Whether to double the last output channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ in_channels: int = 3 out_channels: int = 3 up_block_types: Tuple[str] = ("UpDecoderBlock2D",) @@ -401,6 +562,7 @@ class FlaxDecoder(nn.Module): layers_per_block: int = 2 norm_num_groups: int = 32 act_fn: str = "silu" + dtype: jnp.dtype = jnp.float32 def setup(self): block_out_channels = self.block_out_channels @@ -508,6 +670,44 @@ def mode(self): @flax_register_to_config class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + Flax Implementation of Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational + Bayes by Diederik P. Kingma and Max Welling. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. 
+ + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): + DownEncoder block type + up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): + UpDecoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + latent_channels (:obj:`int`, *optional*, defaults to `4`): + Latent space channels + norm_num_groups (:obj:`int`, *optional*, defaults to `2`): + Norm num group + sample_size (:obj:`int`, *optional*, defaults to `32`): + Sample input size + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ in_channels: int = 3 out_channels: int = 3 down_block_types: Tuple[str] = ("DownEncoderBlock2D",) From 7b97067e142035531af57d9c71ef02e1f6947ea3 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 20 Sep 2022 22:53:15 +0200 Subject: [PATCH 06/16] Apply suggestions from code review --- src/diffusers/models/unet_blocks_flax.py | 75 ++++++++++++++++-------- src/diffusers/models/vae_flax.py | 21 ++++--- 2 files changed, 62 insertions(+), 34 deletions(-) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py index 2c61c4fcd93e..39d10c961710 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_blocks_flax.py @@ -24,15 +24,20 @@ class FlaxCrossAttnDownBlock2D(nn.Module): https://arxiv.org/abs/2103.06104 Parameters: - in_channels (:obj:`int`): Input channels - out_channels (:obj:`int`): Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): Number of attention heads of each spatial transformer block add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ in_channels: int out_channels: int @@ -92,13 +97,18 @@ class FlaxDownBlock2D(nn.Module): Flax 2D downsizing block Parameters: - in_channels (:obj:`int`): Input channels - out_channels (:obj:`int`): Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - num_layers 
(:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ in_channels: int out_channels: int @@ -145,15 +155,20 @@ class FlaxCrossAttnUpBlock2D(nn.Module): https://arxiv.org/abs/2103.06104 Parameters: - in_channels (:obj:`int`): Input channels - out_channels (:obj:`int`): Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): Number of attention heads of each spatial transformer block add_upsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add upsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ in_channels: int out_channels: int @@ -216,14 +231,20 @@ class FlaxUpBlock2D(nn.Module): Flax 2D upsampling block Parameters: - in_channels (:obj:`int`): Input channels - out_channels (:obj:`int`): Output channels - prev_output_channel (:obj:`int`): Output channels from the previous block - dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + prev_output_channel (:obj:`int`): + Output channels from the previous block + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ in_channels: int out_channels: int @@ -273,12 +294,16 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104 Parameters: - in_channels (:obj:`int`): Input channels - dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers + in_channels (:obj:`int`): + Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): Number of attention heads of each spatial transformer block - dtype (:obj:`jnp.dtype`, *optional*, defaults 
to jnp.float32): parameters `dtype` + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ in_channels: int dropout: float = 0.0 diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index df64dd3d32ac..ea7ba6cc7cf7 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -23,7 +23,8 @@ class FlaxDecoderOutput(BaseOutput): Args: sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): Decoded output sample of the model. Output of the last layer of the model. - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ sample: jnp.ndarray @@ -50,7 +51,8 @@ class FlaxUpsample2D(nn.Module): Args: in_channels (`int`): Input channels - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ in_channels: int @@ -83,7 +85,8 @@ class FlaxDownsample2D(nn.Module): Args: in_channels (`int`): Input channels - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ in_channels: int @@ -119,7 +122,7 @@ class FlaxResnetBlock2D(nn.Module): use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): Whether to use `nin_shortcut`. This activates a new layer inside ResNet block dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - parameters `dtype` + Parameters `dtype` """ in_channels: int @@ -189,7 +192,7 @@ class FlaxAttentionBlock(nn.Module): num_head_channels (:obj:`int`, *optional*, defaults to `None`): Number of attention heads dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - parameters `dtype` + Parameters `dtype` """ channels: int @@ -264,7 +267,7 @@ class FlaxDownEncoderBlock2D(nn.Module): add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsample layer dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - parameters `dtype` + Parameters `dtype` """ in_channels: int out_channels: int @@ -316,7 +319,7 @@ class FlaxUpEncoderBlock2D(nn.Module): add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsample layer dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - parameters `dtype` + Parameters `dtype` """ in_channels: int out_channels: int @@ -366,7 +369,7 @@ class FlaxUNetMidBlock2D(nn.Module): attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`): Number of attention heads for each attention block dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - parameters `dtype` + Parameters `dtype` """ in_channels: int dropout: float = 0.0 @@ -445,7 +448,7 @@ class FlaxEncoder(nn.Module): double_z (:obj:`bool`, *optional*, defaults to `False`): Whether to double the last output channels dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - parameters `dtype` + Parameters `dtype` """ in_channels: int = 3 out_channels: int = 3 From edc00f0c82f1efd3a4abf5677a14734c87fe73df Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 20:55:56 +0000 Subject: [PATCH 07/16] modify `unet_2d_condition_flax.py` --- .../models/unet_2d_condition_flax.py | 31 ++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py 
b/src/diffusers/models/unet_2d_condition_flax.py index d0fcd9f6ae13..bfa326d169ac 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -39,10 +39,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library implements for all the models (such as downloading or saving, etc.) + Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + Parameters: - sample_size (`int`, *optional*): The size of the input sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + sample_size (`int`, *optional*): + The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): + The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): + The number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" @@ -51,10 +64,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D" block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. - dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + attention_head_dim (`int`, *optional*, defaults to 8): + The dimension of the attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): + The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): + Dropout probability for down, up and bottleneck blocks. 
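
A minimal forward-pass sketch for the model documented by these parameters, assuming a Flax-enabled install of `diffusers`. It uses the default configuration (4 latent channels, cross-attention dimension 768, 32x32 samples) and a `(batch, channels, height, width)` sample layout mirroring the PyTorch model; the concrete batch size, timestep, text sequence length and random initialization are illustrative assumptions only, not part of the patch.

```python
import jax
import jax.numpy as jnp

from diffusers import FlaxUNet2DConditionModel

# Default (Stable-Diffusion-sized) configuration, so this is memory-hungry.
unet = FlaxUNet2DConditionModel()

sample = jnp.zeros((1, 4, 32, 32))               # (batch, in_channels, sample_size, sample_size), assumed NCHW
timesteps = jnp.array([10], dtype=jnp.int32)     # one diffusion timestep per batch element (illustrative value)
encoder_hidden_states = jnp.zeros((1, 77, 768))  # text-encoder features; last dim must equal cross_attention_dim

# Random parameters for illustration only; in practice they come from pretrained weights.
params = unet.init(jax.random.PRNGKey(0), sample, timesteps, encoder_hidden_states)
out = unet.apply(params, sample, timesteps, encoder_hidden_states)
print(out.sample.shape)                          # (1, 4, 32, 32): out_channels and spatial size preserved
```
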
""" sample_size: int = 32 From dae3064a5f7e3038dd2ac910878b55b4e7dc35fd Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 20:57:15 +0000 Subject: [PATCH 08/16] make style --- src/diffusers/models/unet_blocks_flax.py | 50 ++++++++++++------------ src/diffusers/models/vae_flax.py | 6 +-- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py index 39d10c961710..a4b358fd3ca3 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_blocks_flax.py @@ -24,19 +24,19 @@ class FlaxCrossAttnDownBlock2D(nn.Module): https://arxiv.org/abs/2103.06104 Parameters: - in_channels (:obj:`int`): + in_channels (:obj:`int`): Input channels - out_channels (:obj:`int`): + out_channels (:obj:`int`): Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): + num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): Number of attention heads of each spatial transformer block add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int @@ -97,17 +97,17 @@ class FlaxDownBlock2D(nn.Module): Flax 2D downsizing block Parameters: - in_channels (:obj:`int`): + in_channels (:obj:`int`): Input channels - out_channels (:obj:`int`): + out_channels (:obj:`int`): Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): + num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int @@ -155,19 +155,19 @@ class FlaxCrossAttnUpBlock2D(nn.Module): https://arxiv.org/abs/2103.06104 Parameters: - in_channels (:obj:`int`): + in_channels (:obj:`int`): Input channels - out_channels (:obj:`int`): + out_channels (:obj:`int`): Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): + num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): Number of attention heads of each spatial transformer block add_upsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add upsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int @@ -231,19 +231,19 @@ class FlaxUpBlock2D(nn.Module): Flax 2D upsampling block Parameters: - in_channels (:obj:`int`): + in_channels (:obj:`int`): Input channels - out_channels (:obj:`int`): + out_channels (:obj:`int`): Output channels - prev_output_channel (:obj:`int`): + prev_output_channel (:obj:`int`): 
Output channels from the previous block - dropout (:obj:`float`, *optional*, defaults to 0.0): + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): + num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int @@ -294,15 +294,15 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104 Parameters: - in_channels (:obj:`int`): + in_channels (:obj:`int`): Input channels - dropout (:obj:`float`, *optional*, defaults to 0.0): + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): + num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): Number of attention heads of each spatial transformer block - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index ea7ba6cc7cf7..10019e091167 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -23,7 +23,7 @@ class FlaxDecoderOutput(BaseOutput): Args: sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): Decoded output sample of the model. Output of the last layer of the model. - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ @@ -51,7 +51,7 @@ class FlaxUpsample2D(nn.Module): Args: in_channels (`int`): Input channels - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ @@ -85,7 +85,7 @@ class FlaxDownsample2D(nn.Module): Args: in_channels (`int`): Input channels - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ From 89cd2f3efbdcfb130ad0f79c19d15a1eca50c9aa Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 20 Sep 2022 22:58:13 +0200 Subject: [PATCH 09/16] Apply suggestions from code review --- src/diffusers/models/attention_flax.py | 93 ++++++++++++------------- src/diffusers/models/embeddings_flax.py | 9 ++- 2 files changed, 50 insertions(+), 52 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index fdc12bf92e1f..ddb267867d6b 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -20,22 +20,17 @@ class FlaxAttentionBlock(nn.Module): r""" A multi-head attention as described in: https://arxiv.org/abs/1706.03762 - This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to - general usage and behavior. 
- - Finally, this model supports inherent JAX features such as: - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - Parameters: - query_dim (:obj:`int`): hidden states dimension - heads (:obj:`int`, *optional*, defaults to 8): Number of heads - dim_head (:obj:`int`, *optional*, defaults to 64): hidden states dimension inside each head - dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + query_dim (:obj:`int`): + Input hidden states dimension + heads (:obj:`int`, *optional*, defaults to 8): + Number of heads + dim_head (:obj:`int`, *optional*, defaults to 64): + Hidden states dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ query_dim: int @@ -99,22 +94,18 @@ class FlaxBasicTransformerBlock(nn.Module): A transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: https://arxiv.org/abs/1706.03762 - This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to - general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: - dim (:obj:`int`): Inner hidden states dimension - n_heads (:obj:`int`): Number of heads - d_head (:obj:`int`): Hidden states dimension inside each head - dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + dim (:obj:`int`): + Inner hidden states dimension + n_heads (:obj:`int`): + Number of heads + d_head (:obj:`int`): + Hidden states dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ dim: int n_heads: int @@ -156,23 +147,20 @@ class FlaxSpatialTransformer(nn.Module): A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: https://arxiv.org/pdf/1506.02025.pdf - This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to - general usage and behavior. 
- - Finally, this model supports inherent JAX features such as: - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: - in_channels (:obj:`int`): Input number of channels - n_heads (:obj:`int`): Number of heads - d_head (:obj:`int`): Hidden states dimension inside each head - depth (:obj:`int`, *optional*, defaults to 1): Number of transformers block - dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + in_channels (:obj:`int`): + Input number of channels + n_heads (:obj:`int`): + Number of heads + d_head (:obj:`int`): + Hidden states dimension inside each head + depth (:obj:`int`, *optional*, defaults to 1): + Number of transformers block + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ in_channels: int n_heads: int @@ -231,9 +219,12 @@ class FlaxGluFeedForward(nn.Module): https://arxiv.org/abs/2002.05202 Parameters: - dim (:obj:`int`): Inner hidden states dimension - dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + dim (:obj:`int`): + Inner hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ dim: int dropout: float = 0.0 @@ -257,8 +248,12 @@ class FlaxGEGLU(nn.Module): https://arxiv.org/abs/2002.05202. arameters: - dim (:obj:`int`): Input hidden states dimension dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout - rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + dim (:obj:`int`): + Input hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ dim: int dropout: float = 0.0 diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 521347b6efc8..ef175b2a460c 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -41,8 +41,10 @@ class FlaxTimestepEmbedding(nn.Module): Time step Embedding Module. Learns embeddings for input time steps. 
Args: - time_embed_dim (`int`, *optional*, defaults to `32`): time step embedding dimension - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` + time_embed_dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ time_embed_dim: int = 32 dtype: jnp.dtype = jnp.float32 @@ -60,7 +62,8 @@ class FlaxTimesteps(nn.Module): Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 Args: - dim (`int`, *optional*, defaults to `32`): time step embedding dimension + dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension """ dim: int = 32 From ce3569708b9a7cfe4f2426d7bf63af523e5bd731 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 20:58:27 +0000 Subject: [PATCH 10/16] make style --- src/diffusers/models/attention_flax.py | 46 ++++++++++++------------- src/diffusers/models/embeddings_flax.py | 6 ++-- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index ddb267867d6b..8c5d6f211c1e 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -21,15 +21,15 @@ class FlaxAttentionBlock(nn.Module): A multi-head attention as described in: https://arxiv.org/abs/1706.03762 Parameters: - query_dim (:obj:`int`): + query_dim (:obj:`int`): Input hidden states dimension - heads (:obj:`int`, *optional*, defaults to 8): + heads (:obj:`int`, *optional*, defaults to 8): Number of heads - dim_head (:obj:`int`, *optional*, defaults to 64): + dim_head (:obj:`int`, *optional*, defaults to 64): Hidden states dimension inside each head - dropout (:obj:`float`, *optional*, defaults to 0.0): + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ @@ -96,15 +96,15 @@ class FlaxBasicTransformerBlock(nn.Module): Parameters: - dim (:obj:`int`): + dim (:obj:`int`): Inner hidden states dimension - n_heads (:obj:`int`): + n_heads (:obj:`int`): Number of heads d_head (:obj:`int`): Hidden states dimension inside each head - dropout (:obj:`float`, *optional*, defaults to 0.0): + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ dim: int @@ -149,17 +149,17 @@ class FlaxSpatialTransformer(nn.Module): Parameters: - in_channels (:obj:`int`): + in_channels (:obj:`int`): Input number of channels - n_heads (:obj:`int`): + n_heads (:obj:`int`): Number of heads - d_head (:obj:`int`): + d_head (:obj:`int`): Hidden states dimension inside each head - depth (:obj:`int`, *optional*, defaults to 1): + depth (:obj:`int`, *optional*, defaults to 1): Number of transformers block - dropout (:obj:`float`, *optional*, defaults to 0.0): + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int @@ -219,11 +219,11 @@ class FlaxGluFeedForward(nn.Module): https://arxiv.org/abs/2002.05202 Parameters: - dim (:obj:`int`): + dim (:obj:`int`): Inner hidden states dimension - dropout (:obj:`float`, *optional*, 
defaults to 0.0): + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ dim: int @@ -248,11 +248,11 @@ class FlaxGEGLU(nn.Module): https://arxiv.org/abs/2002.05202. arameters: - dim (:obj:`int`): - Input hidden states dimension - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dim (:obj:`int`): + Input hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ dim: int diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index ef175b2a460c..9dd2aa268afc 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -41,9 +41,9 @@ class FlaxTimestepEmbedding(nn.Module): Time step Embedding Module. Learns embeddings for input time steps. Args: - time_embed_dim (`int`, *optional*, defaults to `32`): + time_embed_dim (`int`, *optional*, defaults to `32`): Time step embedding dimension - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ time_embed_dim: int = 32 @@ -62,7 +62,7 @@ class FlaxTimesteps(nn.Module): Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 Args: - dim (`int`, *optional*, defaults to `32`): + dim (`int`, *optional*, defaults to `32`): Time step embedding dimension """ dim: int = 32 From 04ce16ee537736aa28a2e0e978a2694d2c1ae051 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 20 Sep 2022 23:24:50 +0200 Subject: [PATCH 11/16] Apply suggestions from code review --- src/diffusers/models/attention_flax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 8c5d6f211c1e..c49d29893a87 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -18,7 +18,7 @@ class FlaxAttentionBlock(nn.Module): r""" - A multi-head attention as described in: https://arxiv.org/abs/1706.03762 + A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 Parameters: query_dim (:obj:`int`): @@ -91,7 +91,7 @@ def __call__(self, hidden_states, context=None, deterministic=True): class FlaxBasicTransformerBlock(nn.Module): r""" - A transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: + A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: https://arxiv.org/abs/1706.03762 From 8c52c023c5f90d34d60e0134b67d743b145f1c24 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 21:27:41 +0000 Subject: [PATCH 12/16] fix indent --- src/diffusers/models/attention_flax.py | 44 +++++++++++++------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index c49d29893a87..6ee3e94dd5ed 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -22,15 +22,15 @@ class FlaxAttentionBlock(nn.Module): Parameters: query_dim (:obj:`int`): - Input hidden states dimension + Input 
hidden states dimension heads (:obj:`int`, *optional*, defaults to 8): - Number of heads + Number of heads dim_head (:obj:`int`, *optional*, defaults to 64): - Hidden states dimension inside each head + Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate + Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` + Parameters `dtype` """ query_dim: int @@ -97,15 +97,15 @@ class FlaxBasicTransformerBlock(nn.Module): Parameters: dim (:obj:`int`): - Inner hidden states dimension + Inner hidden states dimension n_heads (:obj:`int`): - Number of heads + Number of heads d_head (:obj:`int`): - Hidden states dimension inside each head + Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate + Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` + Parameters `dtype` """ dim: int n_heads: int @@ -150,17 +150,17 @@ class FlaxSpatialTransformer(nn.Module): Parameters: in_channels (:obj:`int`): - Input number of channels + Input number of channels n_heads (:obj:`int`): - Number of heads + Number of heads d_head (:obj:`int`): - Hidden states dimension inside each head + Hidden states dimension inside each head depth (:obj:`int`, *optional*, defaults to 1): - Number of transformers block + Number of transformers block dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate + Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` + Parameters `dtype` """ in_channels: int n_heads: int @@ -220,11 +220,11 @@ class FlaxGluFeedForward(nn.Module): Parameters: dim (:obj:`int`): - Inner hidden states dimension + Inner hidden states dimension dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate + Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` + Parameters `dtype` """ dim: int dropout: float = 0.0 @@ -249,11 +249,11 @@ class FlaxGEGLU(nn.Module): arameters: dim (:obj:`int`): - Input hidden states dimension + Input hidden states dimension dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate + Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` + Parameters `dtype` """ dim: int dropout: float = 0.0 From 29d643ebafc55b7c9158cc2ccd528dc20d652e63 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 21:28:24 +0000 Subject: [PATCH 13/16] fix typo --- src/diffusers/models/attention_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 6ee3e94dd5ed..1745265b91e1 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -247,7 +247,7 @@ class FlaxGEGLU(nn.Module): Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. 
- arameters: + Parameters: dim (:obj:`int`): Input hidden states dimension dropout (:obj:`float`, *optional*, defaults to 0.0): From 309f445ff7dcef28a400096629d832a19cc706f7 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 21:29:14 +0000 Subject: [PATCH 14/16] fix indent unet --- src/diffusers/models/unet_blocks_flax.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py index a4b358fd3ca3..c23b7844c3fd 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_blocks_flax.py @@ -25,19 +25,19 @@ class FlaxCrossAttnDownBlock2D(nn.Module): Parameters: in_channels (:obj:`int`): - Input channels + Input channels out_channels (:obj:`int`): - Output channels + Output channels dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate + Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): - Number of attention blocks layers + Number of attention blocks layers attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): Number of attention heads of each spatial transformer block add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsampling layer before each final output dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` + Parameters `dtype` """ in_channels: int out_channels: int From fb5ee88bcecd5ec73bd3b500534d40e4e5740986 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 21 Sep 2022 12:40:37 +0200 Subject: [PATCH 15/16] Update src/diffusers/models/vae_flax.py --- src/diffusers/models/vae_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 10019e091167..d686b53660ce 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -117,7 +117,7 @@ class FlaxResnetBlock2D(nn.Module): Input channels out_channels (`int`): Output channels - dropout_prob (:obj:`float`, *optional*, defaults to 0.0): + dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): Whether to use `nin_shortcut`. 
This activates a new layer inside ResNet block From fc52b2cd2fff5b9da90125fe6eb8428fc280d807 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 23 Sep 2022 12:42:23 +0200 Subject: [PATCH 16/16] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/models/vae_flax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index dffd6b758264..b3261b11cf6c 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -549,7 +549,7 @@ class FlaxDecoder(nn.Module): Tuple containing the number of output channels for each block layers_per_block (:obj:`int`, *optional*, defaults to `2`): Number of Resnet layer for each block - norm_num_groups (:obj:`int`, *optional*, defaults to `2`): + norm_num_groups (:obj:`int`, *optional*, defaults to `32`): norm num group act_fn (:obj:`str`, *optional*, defaults to `silu`): Activation function @@ -704,7 +704,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): Activation function latent_channels (:obj:`int`, *optional*, defaults to `4`): Latent space channels - norm_num_groups (:obj:`int`, *optional*, defaults to `2`): + norm_num_groups (:obj:`int`, *optional*, defaults to `32`): Norm num group sample_size (:obj:`int`, *optional*, defaults to `32`): Sample input size
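
To close, a minimal sketch of how the smaller Flax building blocks documented in this series are exercised. Module names and call signatures follow the diffs above; the dimensions, shapes and PRNG seed are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import FlaxAttentionBlock
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps

rng = jax.random.PRNGKey(0)

# Sinusoidal timestep features followed by the learned projection.
timesteps = jnp.array([1, 250, 999])
t_proj, _ = FlaxTimesteps(dim=320).init_with_output(rng, timesteps)  # (3, 320), no learned parameters
temb_module = FlaxTimestepEmbedding(time_embed_dim=1280)
temb_params = temb_module.init(rng, t_proj)
temb = temb_module.apply(temb_params, t_proj)                        # (3, 1280)

# Self-attention over flattened spatial positions, as used inside the transformer blocks.
attn = FlaxAttentionBlock(query_dim=320, heads=8, dim_head=40)
hidden_states = jnp.ones((1, 64, 320))                               # (batch, height * width, channels)
attn_params = attn.init(rng, hidden_states)
out = attn.apply(attn_params, hidden_states)                         # same shape as the input
```
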