From 35c76bf636ab5cecd81eb8a86e6784c452e7fa9b Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Tue, 27 Jun 2023 14:08:17 +0100 Subject: [PATCH 1/4] Add dropout for conditioning cross-attention blocks. --- .../networks/nets/diffusion_model_unet.py | 25 +++++- tests/test_dropout.py | 81 +++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) create mode 100644 tests/test_dropout.py diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index c69e9732..2de6705d 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -911,6 +911,7 @@ class CrossAttnDownBlock(nn.Module): cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers """ def __init__( @@ -930,6 +931,7 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, + dropout_cattn: float = 0.0 ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -962,6 +964,7 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, + dropout=dropout_cattn ) ) @@ -1100,6 +1103,7 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, + dropout_cattn: float = 0.0 ) -> None: super().__init__() self.attention = None @@ -1123,6 +1127,7 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, + dropout=dropout_cattn ) self.resnet_2 = ResnetBlock( spatial_dims=spatial_dims, @@ -1266,7 +1271,7 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, - use_flash_attention: bool = False, + use_flash_attention: bool = False ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1363,6 +1368,7 @@ class CrossAttnUpBlock(nn.Module): cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers """ def __init__( @@ -1382,6 +1388,7 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, + dropout_cattn: float = 0.0 ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1415,6 +1422,7 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, + dropout=dropout_cattn ) ) @@ -1478,6 +1486,7 @@ def get_down_block( cross_attention_dim: int | None, upcast_attention: bool = False, use_flash_attention: bool = False, + dropout_cattn: float = 0.0 ) -> nn.Module: if with_attn: return AttnDownBlock( @@ -1509,6 +1518,7 @@ def get_down_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn ) else: return DownBlock( @@ -1536,6 +1546,7 @@ def get_mid_block( cross_attention_dim: int | None, upcast_attention: bool = False, use_flash_attention: bool = False, + dropout_cattn: float = 0.0 ) -> nn.Module: if with_conditioning: return CrossAttnMidBlock( @@ -1549,6 +1560,7 @@ def get_mid_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn ) else: return AttnMidBlock( @@ -1580,6 +1592,7 @@ def get_up_block( cross_attention_dim: int | None, upcast_attention: bool = False, use_flash_attention: bool = False, + dropout_cattn: float = 0.0 ) -> nn.Module: if with_attn: return AttnUpBlock( @@ -1613,6 +1626,7 @@ def get_up_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn ) else: return UpBlock( @@ -1653,6 +1667,7 @@ class DiffusionModelUNet(nn.Module): classes. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers """ def __init__( @@ -1673,6 +1688,7 @@ def __init__( num_class_embeds: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, + dropout_cattn: float = 0.0 ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1684,6 +1700,10 @@ def __init__( raise ValueError( "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." ) + if dropout_cattn > 1.0 or dropout_cattn < 0.0: + raise ValueError( + "Dropout cannot be negative or >1.0!" 
+ ) # All number of channels should be multiple of num_groups if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): @@ -1773,6 +1793,7 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn ) self.down_blocks.append(down_block) @@ -1790,6 +1811,7 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn ) # up @@ -1824,6 +1846,7 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn ) self.up_blocks.append(up_block) diff --git a/tests/test_dropout.py b/tests/test_dropout.py new file mode 100644 index 00000000..1c0bf57f --- /dev/null +++ b/tests/test_dropout.py @@ -0,0 +1,81 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +import unittest +from parameterized import parameterized +from generative.networks.nets import DiffusionModelUNet + +CASE_WRONG = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 3.0 + } + ], +] + +CASE_OK = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 0.25 + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3 + } + ], +] + +class TestDiffusionModelUNetDropout(unittest.TestCase): + @parameterized.expand(CASE_WRONG) + def test_wrong(self, input_param): + with self.assertRaises(ValueError): + net = DiffusionModelUNet(**input_param) + + @parameterized.expand(CASE_OK) + def test_right(self, input_param): + net = DiffusionModelUNet(**input_param) + +if __name__ == "__main__": + unittest.main() From f55088792ddc96fc138f7f4a882ac114544ef175 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 1 Nov 2023 08:47:03 +0000 Subject: [PATCH 2/4] Removed test_dropout. Included tests for this add-on (dropout possibility in cross-attention blocks) in the main test_diffusion_model_unet. 
--- tests/test_diffusion_model_unet.py | 64 +++++++++++++++++++++++ tests/test_dropout.py | 81 ------------------------------ 2 files changed, 64 insertions(+), 81 deletions(-) delete mode 100644 tests/test_dropout.py diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index b02c37b1..a1073c13 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -231,6 +231,59 @@ ], ] +DROPOUT_OK = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 0.25 + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3 + } + ], +] + +DROPOUT_WRONG = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 3.0 + } + ], +] + class TestDiffusionModelUNet2D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D) @@ -524,6 +577,17 @@ def test_script_conditioned_3d_models(self): net, torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3)) ) + # Test dropout specification for cross-attention blocks + @parameterized.expand(DROPOUT_WRONG) + def test_wrong_dropout(self, input_param): + with self.assertRaises(ValueError): + net = DiffusionModelUNet(**input_param) + + @parameterized.expand(DROPOUT_OK) + def test_right_dropout(self, input_param): + net = DiffusionModelUNet(**input_param) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_dropout.py b/tests/test_dropout.py deleted file mode 100644 index 1c0bf57f..00000000 --- a/tests/test_dropout.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations -import unittest -from parameterized import parameterized -from generative.networks.nets import DiffusionModelUNet - -CASE_WRONG = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - "dropout_cattn": 3.0 - } - ], -] - -CASE_OK = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - "dropout_cattn": 0.25 - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3 - } - ], -] - -class TestDiffusionModelUNetDropout(unittest.TestCase): - @parameterized.expand(CASE_WRONG) - def test_wrong(self, input_param): - with self.assertRaises(ValueError): - net = DiffusionModelUNet(**input_param) - - @parameterized.expand(CASE_OK) - def test_right(self, input_param): - net = DiffusionModelUNet(**input_param) - -if __name__ == "__main__": - unittest.main() From 76ff96f7aa0ae235d4a414f756b10e6dacf370d3 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 1 Nov 2023 08:18:05 -0600 Subject: [PATCH 3/4] Update tests/test_diffusion_model_unet.py Signed-off-by: Mark Graham --- tests/test_diffusion_model_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index a1073c13..ab650251 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -585,7 +585,7 @@ def test_wrong_dropout(self, input_param): @parameterized.expand(DROPOUT_OK) def test_right_dropout(self, input_param): - net = DiffusionModelUNet(**input_param) + _ = DiffusionModelUNet(**input_param) From 783ece554bd28945c1a3679211f12d2b94b4f7e0 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 1 Nov 2023 08:18:10 -0600 Subject: [PATCH 4/4] Update tests/test_diffusion_model_unet.py Signed-off-by: Mark Graham --- tests/test_diffusion_model_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index ab650251..976e88d4 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -581,7 +581,7 @@ def test_script_conditioned_3d_models(self): @parameterized.expand(DROPOUT_WRONG) def test_wrong_dropout(self, input_param): with self.assertRaises(ValueError): - net = DiffusionModelUNet(**input_param) + _ = DiffusionModelUNet(**input_param) @parameterized.expand(DROPOUT_OK) def test_right_dropout(self, input_param):
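Usage note (illustrative, not part of the patch series): the sketch below shows how the new dropout_cattn argument introduced by these patches would be passed to DiffusionModelUNet, reusing the small conditioned configuration from the DROPOUT_OK test case above. The dummy input/context shapes and the keyword-style forward call are assumptions modelled on the existing conditioned 2D tests, not something these patches define.

import torch

from generative.networks.nets import DiffusionModelUNet

# Same small conditioned configuration as the DROPOUT_OK test case above.
# dropout_cattn=0.25 is an illustrative value; any probability in [0, 1] is accepted.
params = dict(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, False, True),
    num_head_channels=4,
    norm_num_groups=8,
    with_conditioning=True,
    transformer_num_layers=1,
    cross_attention_dim=3,
    dropout_cattn=0.25,
)
net = DiffusionModelUNet(**params)

# Conditioned forward pass with dummy data. Shapes are assumptions following the
# conditioned 2D tests: spatial sizes divisible by 4 (two downsampling levels) and
# a context whose last dimension equals cross_attention_dim.
x = torch.rand((1, 1, 16, 32))
timesteps = torch.randint(0, 1000, (1,)).long()
context = torch.rand((1, 1, 3))
out = net(x, timesteps=timesteps, context=context)

# Values outside [0, 1] are rejected at construction time, as exercised by DROPOUT_WRONG.
try:
    DiffusionModelUNet(**{**params, "dropout_cattn": 3.0})
except ValueError as err:
    print(err)  # "Dropout cannot be negative or >1.0!"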