diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 533da32fa0..f27f73ec60 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -27,13 +27,11 @@ AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL") ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock") - if TYPE_CHECKING: from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType else: AutoencoderKLType = cast(type, AutoencoderKL) - # Set up logging configuration logger = logging.getLogger(__name__) diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index e88dc469c9..392a3d7db2 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -29,7 +29,6 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - CASES_NO_ATTENTION = [ [ { @@ -82,6 +81,7 @@ @unittest.skipUnless(has_generative, "monai-generative required") class TestAutoencoderKlMaisi(unittest.TestCase): + @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): net = AutoencoderKlMaisi(**input_param).to(device) diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py index b522b750c8..7b0e69f2c8 100644 --- a/tests/test_controlnet_maisi.py +++ b/tests/test_controlnet_maisi.py @@ -130,6 +130,7 @@ @SkipIfBeforePyTorchVersion((2, 0)) @skipUnless(has_generative, "monai-generative required") class TestControlNet(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): net = ControlNetMaisi(**input_param) diff --git a/tests/test_diffusion_model_unet_maisi.py b/tests/test_diffusion_model_unet_maisi.py index b5c14192d9..059a4a4ba8 100644 --- a/tests/test_diffusion_model_unet_maisi.py +++ b/tests/test_diffusion_model_unet_maisi.py @@ -26,7 +26,6 @@ if has_generative: from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi - UNCOND_CASES_2D = [ [ { @@ -294,6 +293,7 @@ @skipUnless(has_generative, "monai-generative required") class TestDiffusionModelUNetMaisi2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param): @@ -512,6 +512,7 @@ def test_shape_with_additional_inputs(self, input_param): @skipUnless(has_generative, "monai-generative required") class TestDiffusionModelUNetMaisi3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param):