diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 35f5dc34574c..6b67990dac06 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -94,7 +94,7 @@ def __init__( mid_block_scale_factor: float = 1, downsample_padding: int = 1, act_fn: str = "silu", - attention_head_dim: int = 8, + attention_head_dim: Optional[int] = 8, norm_num_groups: int = 32, norm_eps: float = 1e-5, resnet_time_scale_shift: str = "default", @@ -107,6 +107,17 @@ def __init__( self.sample_size = sample_size time_embed_dim = block_out_channels[0] * 4 + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + # input self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index ba2c09b297b9..daffef765077 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -150,6 +150,27 @@ def __init__( self.sample_size = sample_size + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 806875a9a507..2c5b717ac861 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -236,6 +236,31 @@ def __init__( self.sample_size = sample_size + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + "Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:" + f" {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + "Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:" + f" {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + "Must provide the same number of `only_cross_attention` as `down_block_types`." + f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:" + f" {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = LinearMultiDim( diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 6db22626c9e0..9c9a0a018629 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -56,7 +56,7 @@ def get_dummy_components(self): up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), cross_attention_dim=32, # SD2-specific config below - attention_head_dim=(2, 4, 8, 8), + attention_head_dim=(2, 4), use_linear_projection=True, ) scheduler = DDIMScheduler( diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 49d0e4dbdca2..14fcbef2d164 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -65,7 +65,7 @@ def get_dummy_components(self): down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), cross_attention_dim=32, - attention_head_dim=(2, 4, 8, 8), + attention_head_dim=(2, 4), use_linear_projection=True, ) scheduler = PNDMScheduler(skip_prk_steps=True) @@ -288,7 +288,7 @@ def test_stable_diffusion_depth2img_default_case(self): if torch_device == "mps": expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546]) else: - expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711]) + expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -309,7 +309,7 @@ def test_stable_diffusion_depth2img_negative_prompt(self): if torch_device == "mps": expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335]) else: - expected_slice = np.array([0.6332, 0.5167, 0.3911, 0.4446, 0.5971, 0.4619, 0.3821, 0.5323, 0.4621]) + expected_slice = np.array([0.6296, 0.5125, 0.3890, 0.4456, 0.5955, 0.4621, 0.3810, 0.5310, 0.4626]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -331,7 +331,7 @@ def test_stable_diffusion_depth2img_multiple_init_images(self): if torch_device == "mps": expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551]) else: - expected_slice = np.array([0.6248, 0.5206, 0.6007, 0.6749, 0.5022, 0.6442, 0.5352, 0.4140, 0.4681]) + expected_slice = np.array([0.6267, 0.5232, 0.6001, 0.6738, 0.5029, 0.6429, 0.5364, 0.4159, 0.4674]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -386,7 +386,7 @@ def test_stable_diffusion_depth2img_pil(self): if torch_device == "mps": expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439]) else: - expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711]) + expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index 0ec47a510b50..58bdd465e422 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -47,7 +47,7 @@ def get_dummy_components(self): up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), cross_attention_dim=32, # SD2-specific config below - attention_head_dim=(2, 4, 8, 8), + attention_head_dim=(2, 4), use_linear_projection=True, ) scheduler = PNDMScheduler(skip_prk_steps=True) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py index a06f13632a3a..39cc546f6774 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -56,7 +56,7 @@ def dummy_cond_unet(self): up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), cross_attention_dim=32, # SD2-specific config below - attention_head_dim=(2, 4, 8, 8), + attention_head_dim=(2, 4), use_linear_projection=True, ) return model