diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py
index b3261b11cf6c..074133a05c4a 100644
--- a/src/diffusers/models/vae_flax.py
+++ b/src/diffusers/models/vae_flax.py
@@ -119,6 +119,8 @@ class FlaxResnetBlock2D(nn.Module):
             Output channels
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for group norm.
         use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
             Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -128,13 +130,14 @@ class FlaxResnetBlock2D(nn.Module):
     in_channels: int
     out_channels: int = None
     dropout: float = 0.0
+    groups: int = 32
     use_nin_shortcut: bool = None
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
         out_channels = self.in_channels if self.out_channels is None else self.out_channels
 
-        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
         self.conv1 = nn.Conv(
             out_channels,
             kernel_size=(3, 3),
@@ -143,7 +146,7 @@ def setup(self):
             dtype=self.dtype,
         )
 
-        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
         self.dropout_layer = nn.Dropout(self.dropout)
         self.conv2 = nn.Conv(
             out_channels,
@@ -191,12 +194,15 @@ class FlaxAttentionBlock(nn.Module):
             Input channels
         num_head_channels (:obj:`int`, *optional*, defaults to `None`):
             Number of attention heads
+        num_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for group norm
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
 
     """
     channels: int
     num_head_channels: int = None
+    num_groups: int = 32
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -204,7 +210,7 @@ def setup(self):
 
         dense = partial(nn.Dense, self.channels, dtype=self.dtype)
 
-        self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
         self.query, self.key, self.value = dense(), dense(), dense()
         self.proj_attn = dense()
 
@@ -264,6 +270,8 @@ class FlaxDownEncoderBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet block group norm
         add_downsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add downsample layer
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -273,6 +281,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     add_downsample: bool = True
     dtype: jnp.dtype = jnp.float32
 
@@ -285,6 +294,7 @@ def setup(self):
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout=self.dropout,
+                groups=self.resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -303,9 +313,9 @@ def __call__(self, hidden_states, deterministic=True):
         return hidden_states
 
 
-class FlaxUpEncoderBlock2D(nn.Module):
+class FlaxUpDecoderBlock2D(nn.Module):
     r"""
-    Flax Resnet blocks-based Encoder block for diffusion-based VAE.
+    Flax Resnet blocks-based Decoder block for diffusion-based VAE.
 
     Parameters:
         in_channels (:obj:`int`):
@@ -316,8 +326,10 @@ class FlaxUpEncoderBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
-        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
-            Whether to add downsample layer
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet block group norm
+        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add upsample layer
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -325,6 +337,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     add_upsample: bool = True
     dtype: jnp.dtype = jnp.float32
 
@@ -336,6 +349,7 @@ def setup(self):
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout=self.dropout,
+                groups=self.resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -366,6 +380,8 @@ class FlaxUNetMidBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet and Attention block group norm
         attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
             Number of attention heads for each attention block
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -374,16 +390,20 @@ class FlaxUNetMidBlock2D(nn.Module):
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     attn_num_head_channels: int = 1
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
+
         # there is always at least one resnet
         resnets = [
             FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout=self.dropout,
+                groups=resnet_groups,
                 dtype=self.dtype,
             )
         ]
@@ -392,7 +412,10 @@ def setup(self):
 
         for _ in range(self.num_layers):
             attn_block = FlaxAttentionBlock(
-                channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
+                channels=self.in_channels,
+                num_head_channels=self.attn_num_head_channels,
+                num_groups=resnet_groups,
+                dtype=self.dtype,
             )
             attentions.append(attn_block)
 
@@ -400,6 +423,7 @@ def setup(self):
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout=self.dropout,
+                groups=resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -441,7 +465,7 @@ class FlaxEncoder(nn.Module):
             Tuple containing the number of output channels for each block
         layers_per_block (:obj:`int`, *optional*, defaults to `2`):
             Number of Resnet layer for each block
-        norm_num_groups (:obj:`int`, *optional*, defaults to `2`):
+        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
             norm num group
         act_fn (:obj:`str`, *optional*, defaults to `silu`):
             Activation function
@@ -483,6 +507,7 @@ def setup(self):
                 in_channels=input_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block,
+                resnet_groups=self.norm_num_groups,
                 add_downsample=not is_final_block,
                 dtype=self.dtype,
             )
@@ -491,12 +516,15 @@ def setup(self):
 
         # middle
         self.mid_block = FlaxUNetMidBlock2D(
-            in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+            in_channels=block_out_channels[-1],
+            resnet_groups=self.norm_num_groups,
+            attn_num_head_channels=None,
+            dtype=self.dtype,
         )
 
         # end
         conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
-        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
         self.conv_out = nn.Conv(
             conv_out_channels,
             kernel_size=(3, 3),
@@ -581,7 +609,10 @@ def setup(self):
 
         # middle
         self.mid_block = FlaxUNetMidBlock2D(
-            in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+            in_channels=block_out_channels[-1],
+            resnet_groups=self.norm_num_groups,
+            attn_num_head_channels=None,
+            dtype=self.dtype,
         )
 
         # upsampling
@@ -594,10 +625,11 @@ def setup(self):
 
             is_final_block = i == len(block_out_channels) - 1
 
-            up_block = FlaxUpEncoderBlock2D(
+            up_block = FlaxUpDecoderBlock2D(
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block + 1,
+                resnet_groups=self.norm_num_groups,
                 add_upsample=not is_final_block,
                 dtype=self.dtype,
             )
@@ -607,7 +639,7 @@ def setup(self):
         self.up_blocks = up_blocks
 
         # end
-        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
         self.conv_out = nn.Conv(
             self.out_channels,
             kernel_size=(3, 3),
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 0177d30abac9..f44b9cd394c9 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -14,6 +14,8 @@
 import requests
 from packaging import version
 
+from .import_utils import is_flax_available
+
 
 global_rng = random.Random()
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -89,6 +91,13 @@ def slow(test_case):
     return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
 
 
+def require_flax(test_case):
+    """
+    Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
+    """
+    return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
+
+
 def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
     """
     Args:
diff --git a/tests/test_modeling_common_flax.py b/tests/test_modeling_common_flax.py
new file mode 100644
index 000000000000..61849b22318f
--- /dev/null
+++ b/tests/test_modeling_common_flax.py
@@ -0,0 +1,44 @@
+from diffusers.utils import is_flax_available
+from diffusers.utils.testing_utils import require_flax
+
+
+if is_flax_available():
+    import jax
+
+
+@require_flax
+class FlaxModelTesterMixin:
+    def test_output(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
+        jax.lax.stop_gradient(variables)
+
+        output = model.apply(variables, inputs_dict["sample"])
+
+        if isinstance(output, dict):
+            output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+    def test_forward_with_norm_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["block_out_channels"] = (16, 32)
+
+        model = self.model_class(**init_dict)
+        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
+        jax.lax.stop_gradient(variables)
+
+        output = model.apply(variables, inputs_dict["sample"])
+
+        if isinstance(output, dict):
+            output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
diff --git a/tests/test_models_vae_flax.py b/tests/test_models_vae_flax.py
new file mode 100644
index 000000000000..e5c56b61a5a4
--- /dev/null
+++ b/tests/test_models_vae_flax.py
@@ -0,0 +1,39 @@
+import unittest
+
+from diffusers import FlaxAutoencoderKL
+from diffusers.utils import is_flax_available
+from diffusers.utils.testing_utils import require_flax
+
+from .test_modeling_common_flax import FlaxModelTesterMixin
+
+
+if is_flax_available():
+    import jax
+
+
+@require_flax
+class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
+    model_class = FlaxAutoencoderKL
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_channels = 3
+        sizes = (32, 32)
+
+        prng_key = jax.random.PRNGKey(0)
+        image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
+
+        return {"sample": image, "prng_key": prng_key}
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "block_out_channels": [32, 64],
+            "in_channels": 3,
+            "out_channels": 3,
+            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            "latent_channels": 4,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict