From 1b4b66092009d363b68baa37fc2f05ad9eb0058f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 9 Jan 2023 11:47:43 -0600 Subject: [PATCH] Adds error and corresponding test --- .../networks/nets/diffusion_model_unet.py | 3 +++ tests/test_diffusion_model_unet.py | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 7aabfbba..23020f29 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -1482,6 +1482,7 @@ def __init__( self.num_res_blocks = num_res_blocks self.attention_levels = attention_levels self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning # input self.conv_in = Convolution( @@ -1623,6 +1624,8 @@ def forward( h = self.conv_in(x) # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") down_block_res_samples: List[torch.Tensor] = [h] for downsample_block in self.down_blocks: h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index f9925b96..e661a213 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -168,6 +168,26 @@ def test_with_conditioning_cross_attention_dim_none(self): norm_num_groups=8, ) + def test_context_with_conditioning_none(self): + with self.assertRaises(ValueError): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + def test_shape_conditioned_models_class_conditioning(self): net = DiffusionModelUNet( spatial_dims=2,