From d603fc433f076209b5312704d8fcbafcb867dbbf Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Thu, 8 Aug 2024 23:32:17 +0000 Subject: [PATCH 1/8] update import Signed-off-by: Pengfei Guo --- .../maisi/networks/controlnet_maisi.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 3641124b7d..328e222732 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -17,18 +17,22 @@ from monai.utils import optional_import -ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet") -get_timestep_embedding, has_get_timestep_embedding = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" -) +# ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet") +# get_timestep_embedding, has_get_timestep_embedding = optional_import( +# "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" +# ) -if TYPE_CHECKING: - from generative.networks.nets.controlnet import ControlNet as ControlNetType -else: - ControlNetType = cast(type, ControlNet) +# if TYPE_CHECKING: +# from generative.networks.nets.controlnet import ControlNet as ControlNetType +# else: +# ControlNetType = cast(type, ControlNet) +from monai.networks.nets.controlnet import ControlNet +from monai.networks.nets.diffusion_model_unet import get_timestep_embedding -class ControlNetMaisi(ControlNetType): + + +class ControlNetMaisi(ControlNet): """ Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image Diffusion Models" (https://arxiv.org/abs/2302.05543) From e034bf3a7973dd682883cc2116274681a83b2f39 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Aug 2024 23:35:46 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 328e222732..9fcb111c2e 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -11,11 +11,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, cast +from typing import Sequence import torch -from monai.utils import optional_import # ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet") # get_timestep_embedding, has_get_timestep_embedding = optional_import( From cea517b6c637f3e021a3d0f1b23078e3e5d9f524 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 11 Aug 2024 19:57:41 +0000 Subject: [PATCH 3/8] update Signed-off-by: Pengfei Guo --- .../generation/maisi/networks/controlnet_maisi.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 9fcb111c2e..5bc77bc6d2 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -15,17 +15,6 @@ import torch - -# ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet") -# get_timestep_embedding, has_get_timestep_embedding = optional_import( -# "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" -# ) - -# if TYPE_CHECKING: -# from generative.networks.nets.controlnet import ControlNet as ControlNetType -# else: -# ControlNetType = cast(type, ControlNet) - from monai.networks.nets.controlnet import ControlNet from monai.networks.nets.diffusion_model_unet import get_timestep_embedding @@ -52,6 +41,8 @@ class ControlNetMaisi(ControlNet): num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. conditioning_embedding_in_channels: number of input channels for the conditioning embedding. conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. @@ -74,6 +65,8 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, conditioning_embedding_in_channels: int = 1, conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), From 1d3a84210d078c6afb0d802aca5fa632b6f32d73 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 11 Aug 2024 22:01:12 +0000 Subject: [PATCH 4/8] update Signed-off-by: Pengfei Guo --- .../maisi/networks/controlnet_maisi.py | 17 ++++++++------- tests/test_controlnet_maisi.py | 21 +++++++------------ 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 5bc77bc6d2..3d593e3bfd 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -19,7 +19,6 @@ from monai.networks.nets.diffusion_model_unet import get_timestep_embedding - class ControlNetMaisi(ControlNet): """ Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image @@ -41,12 +40,12 @@ class ControlNetMaisi(ControlNet): num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. - include_fc: whether to include the final linear layer. Default to False. - use_combined_linear: whether to use a single linear layer for qkv projection, default to False. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. conditioning_embedding_in_channels: number of input channels for the conditioning embedding. conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. use_checkpointing: if True, use activation checkpointing to save memory. + include_fc: whether to include the final linear layer. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -65,12 +64,12 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, - include_fc: bool = False, - use_combined_linear: bool = False, - use_flash_attention: bool = False, conditioning_embedding_in_channels: int = 1, conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), use_checkpointing: bool = True, + include_fc: bool = False, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__( spatial_dims, @@ -87,9 +86,11 @@ def __init__( cross_attention_dim, num_class_embeds, upcast_attention, - use_flash_attention, conditioning_embedding_in_channels, conditioning_embedding_num_channels, + include_fc, + use_combined_linear, + use_flash_attention, ) self.use_checkpointing = use_checkpointing diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py index 7b0e69f2c8..7087cca667 100644 --- a/tests/test_controlnet_maisi.py +++ b/tests/test_controlnet_maisi.py @@ -12,20 +12,14 @@ from __future__ import annotations import unittest -from unittest import skipUnless import torch from parameterized import parameterized +from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi from monai.networks import eval_mode -from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion -_, has_generative = optional_import("generative") - -if has_generative: - from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi - TEST_CASES = [ [ { @@ -103,16 +97,17 @@ TEST_CASES_ERROR = [ [ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": True, "cross_attention_dim": None}, - "ControlNet expects dimension of the cross-attention conditioning " - "(cross_attention_dim) when using with_conditioning.", + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "to be specified when with_conditioning=True.", ], [ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": False, "cross_attention_dim": 2}, - "ControlNet expects with_conditioning=True when specifying the cross_attention_dim.", + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim.", ], [ {"spatial_dims": 2, "in_channels": 1, "num_channels": (8, 16), "norm_num_groups": 16}, - "ControlNet expects all num_channels being multiple of norm_num_groups", + f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got" + f" channels={(8, 16)} and norm_num_groups={16}", ], [ { @@ -122,13 +117,13 @@ "attention_levels": (True,), "norm_num_groups": 8, }, - "ControlNet expects num_channels being same size of attention_levels", + f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got " + f"channels={(8, 16)} and attention_levels={(True,)}", ], ] @SkipIfBeforePyTorchVersion((2, 0)) -@skipUnless(has_generative, "monai-generative required") class TestControlNet(unittest.TestCase): @parameterized.expand(TEST_CASES) From 1bb9c10e105b9493a51f37fbe8102350971524fa Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 11 Aug 2024 23:18:31 +0000 Subject: [PATCH 5/8] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 3d593e3bfd..269086d971 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -65,7 +65,7 @@ def __init__( num_class_embeds: int | None = None, upcast_attention: bool = False, conditioning_embedding_in_channels: int = 1, - conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), + conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), use_checkpointing: bool = True, include_fc: bool = False, use_combined_linear: bool = False, @@ -102,7 +102,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ) -> tuple[Sequence[torch.Tensor], torch.Tensor]: + ) -> tuple[list[torch.Tensor], torch.Tensor]: emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) h = self._apply_initial_convolution(x) if self.use_checkpointing: From 6cb49b1b9970e4f65f62c3f9126e350988463876 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 11 Aug 2024 23:29:49 +0000 Subject: [PATCH 6/8] update Signed-off-by: Pengfei Guo --- tests/test_controlnet_maisi.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py index 7087cca667..fbdf017d13 100644 --- a/tests/test_controlnet_maisi.py +++ b/tests/test_controlnet_maisi.py @@ -12,14 +12,18 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi from monai.networks import eval_mode +from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion +_, has_einops = optional_import("einops") + TEST_CASES = [ [ { @@ -127,6 +131,7 @@ class TestControlNet(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): net = ControlNetMaisi(**input_param) with eval_mode(net): @@ -140,6 +145,7 @@ def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_ self.assertEqual(result[1].shape, expected_shape) @parameterized.expand(TEST_CASES_CONDITIONAL) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): net = ControlNetMaisi(**input_param) with eval_mode(net): From a645bc550e7a7388f28a7e9914167c9f5df3eeb4 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Mon, 12 Aug 2024 16:23:42 +0000 Subject: [PATCH 7/8] update Signed-off-by: Pengfei Guo --- monai/networks/nets/controlnet.py | 8 ++++---- tests/test_controlnet_maisi.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index 8b08eaae10..65baf908ea 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -174,24 +174,24 @@ def __init__( super().__init__() if with_conditioning is True and cross_attention_dim is None: raise ValueError( - "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) " "to be specified when with_conditioning=True." ) if cross_attention_dim is not None and with_conditioning is False: raise ValueError( - "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + "ControlNet expects with_conditioning=True when specifying the cross_attention_dim." ) # All number of channels should be multiple of num_groups if any((out_channel % norm_num_groups) != 0 for out_channel in channels): raise ValueError( - f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got" + f"ControlNet expects all channels to be a multiple of norm_num_groups, but got" f" channels={channels} and norm_num_groups={norm_num_groups}" ) if len(channels) != len(attention_levels): raise ValueError( - f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got " + f"ControlNet expects channels to have the same length as attention_levels, but got " f"channels={channels} and attention_levels={attention_levels}" ) diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py index fbdf017d13..bfdf25ec6e 100644 --- a/tests/test_controlnet_maisi.py +++ b/tests/test_controlnet_maisi.py @@ -101,16 +101,16 @@ TEST_CASES_ERROR = [ [ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": True, "cross_attention_dim": None}, - "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) " "to be specified when with_conditioning=True.", ], [ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": False, "cross_attention_dim": 2}, - "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim.", + "ControlNet expects with_conditioning=True when specifying the cross_attention_dim.", ], [ {"spatial_dims": 2, "in_channels": 1, "num_channels": (8, 16), "norm_num_groups": 16}, - f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got" + f"ControlNet expects all channels to be a multiple of norm_num_groups, but got" f" channels={(8, 16)} and norm_num_groups={16}", ], [ @@ -121,7 +121,7 @@ "attention_levels": (True,), "norm_num_groups": 8, }, - f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got " + f"ControlNet expects channels to have the same length as attention_levels, but got " f"channels={(8, 16)} and attention_levels={(True,)}", ], ] From 53d9a43ef314ae7e1bb738a559a87cd2801a8c4d Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Mon, 12 Aug 2024 16:41:43 +0000 Subject: [PATCH 8/8] update Signed-off-by: Pengfei Guo --- monai/networks/nets/controlnet.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index 65baf908ea..8b8813597f 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -178,9 +178,7 @@ def __init__( "to be specified when with_conditioning=True." ) if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "ControlNet expects with_conditioning=True when specifying the cross_attention_dim." - ) + raise ValueError("ControlNet expects with_conditioning=True when specifying the cross_attention_dim.") # All number of channels should be multiple of num_groups if any((out_channel % norm_num_groups) != 0 for out_channel in channels):