diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 3641124b7d..269086d971 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -11,24 +11,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, cast +from typing import Sequence import torch -from monai.utils import optional_import +from monai.networks.nets.controlnet import ControlNet +from monai.networks.nets.diffusion_model_unet import get_timestep_embedding -ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet") -get_timestep_embedding, has_get_timestep_embedding = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" -) -if TYPE_CHECKING: - from generative.networks.nets.controlnet import ControlNet as ControlNetType -else: - ControlNetType = cast(type, ControlNet) - - -class ControlNetMaisi(ControlNetType): +class ControlNetMaisi(ControlNet): """ Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image Diffusion Models" (https://arxiv.org/abs/2302.05543) @@ -49,10 +40,12 @@ class ControlNetMaisi(ControlNetType): num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. conditioning_embedding_in_channels: number of input channels for the conditioning embedding. conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. use_checkpointing: if True, use activation checkpointing to save memory. + include_fc: whether to include the final linear layer. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -71,10 +64,12 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, conditioning_embedding_in_channels: int = 1, - conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), + conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), use_checkpointing: bool = True, + include_fc: bool = False, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__( spatial_dims, @@ -91,9 +86,11 @@ def __init__( cross_attention_dim, num_class_embeds, upcast_attention, - use_flash_attention, conditioning_embedding_in_channels, conditioning_embedding_num_channels, + include_fc, + use_combined_linear, + use_flash_attention, ) self.use_checkpointing = use_checkpointing @@ -105,7 +102,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ) -> tuple[Sequence[torch.Tensor], torch.Tensor]: + ) -> tuple[list[torch.Tensor], torch.Tensor]: emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) h = self._apply_initial_convolution(x) if self.use_checkpointing: diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index 8b08eaae10..8b8813597f 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -174,24 +174,22 @@ def __init__( super().__init__() if with_conditioning is True and cross_attention_dim is None: raise ValueError( - "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) " "to be specified when with_conditioning=True." ) if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." - ) + raise ValueError("ControlNet expects with_conditioning=True when specifying the cross_attention_dim.") # All number of channels should be multiple of num_groups if any((out_channel % norm_num_groups) != 0 for out_channel in channels): raise ValueError( - f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got" + f"ControlNet expects all channels to be a multiple of norm_num_groups, but got" f" channels={channels} and norm_num_groups={norm_num_groups}" ) if len(channels) != len(attention_levels): raise ValueError( - f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got " + f"ControlNet expects channels to have the same length as attention_levels, but got " f"channels={channels} and attention_levels={attention_levels}" ) diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py index 7b0e69f2c8..bfdf25ec6e 100644 --- a/tests/test_controlnet_maisi.py +++ b/tests/test_controlnet_maisi.py @@ -17,14 +17,12 @@ import torch from parameterized import parameterized +from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi from monai.networks import eval_mode from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion -_, has_generative = optional_import("generative") - -if has_generative: - from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi +_, has_einops = optional_import("einops") TEST_CASES = [ [ @@ -103,8 +101,8 @@ TEST_CASES_ERROR = [ [ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": True, "cross_attention_dim": None}, - "ControlNet expects dimension of the cross-attention conditioning " - "(cross_attention_dim) when using with_conditioning.", + "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "to be specified when with_conditioning=True.", ], [ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": False, "cross_attention_dim": 2}, @@ -112,7 +110,8 @@ ], [ {"spatial_dims": 2, "in_channels": 1, "num_channels": (8, 16), "norm_num_groups": 16}, - "ControlNet expects all num_channels being multiple of norm_num_groups", + f"ControlNet expects all channels to be a multiple of norm_num_groups, but got" + f" channels={(8, 16)} and norm_num_groups={16}", ], [ { @@ -122,16 +121,17 @@ "attention_levels": (True,), "norm_num_groups": 8, }, - "ControlNet expects num_channels being same size of attention_levels", + f"ControlNet expects channels to have the same length as attention_levels, but got " + f"channels={(8, 16)} and attention_levels={(True,)}", ], ] @SkipIfBeforePyTorchVersion((2, 0)) -@skipUnless(has_generative, "monai-generative required") class TestControlNet(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): net = ControlNetMaisi(**input_param) with eval_mode(net): @@ -145,6 +145,7 @@ def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_ self.assertEqual(result[1].shape, expected_shape) @parameterized.expand(TEST_CASES_CONDITIONAL) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): net = ControlNetMaisi(**input_param) with eval_mode(net):