diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py
index 8a830027..2d4dbb20 100644
--- a/tests/test_latent_diffusion_inferer.py
+++ b/tests/test_latent_diffusion_inferer.py
@@ -27,26 +27,26 @@
             "spatial_dims": 2,
             "in_channels": 1,
             "out_channels": 1,
-            "num_channels": (8, 8, 8),
+            "num_channels": (4, 4),
             "latent_channels": 3,
-            "attention_levels": [False, False, False],
+            "attention_levels": [False, False],
             "num_res_blocks": 1,
             "with_encoder_nonlocal_attn": False,
             "with_decoder_nonlocal_attn": False,
-            "norm_num_groups": 8,
+            "norm_num_groups": 4,
         },
         {
             "spatial_dims": 2,
             "in_channels": 3,
             "out_channels": 3,
-            "num_channels": [8, 8, 8],
-            "norm_num_groups": 8,
-            "attention_levels": [False, False, True],
+            "num_channels": [4, 4],
+            "norm_num_groups": 4,
+            "attention_levels": [False, False],
             "num_res_blocks": 1,
-            "num_head_channels": 8,
+            "num_head_channels": 4,
         },
-        (1, 1, 32, 32),
-        (1, 3, 8, 8),
+        (1, 1, 8, 8),
+        (1, 3, 4, 4),
     ],
     [
         "VQVAE",
@@ -58,8 +58,8 @@
             "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
             "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
             "num_res_layers": 1,
-            "num_channels": [8, 8],
-            "num_res_channels": [8, 8],
+            "num_channels": [4, 4],
+            "num_res_channels": [4, 4],
             "num_embeddings": 16,
             "embedding_dim": 3,
         },
@@ -67,14 +67,14 @@
             "spatial_dims": 2,
             "in_channels": 3,
             "out_channels": 3,
-            "num_channels": [8, 8, 8],
+            "num_channels": [8, 8],
             "norm_num_groups": 8,
-            "attention_levels": [False, False, True],
+            "attention_levels": [False, False],
             "num_res_blocks": 1,
             "num_head_channels": 8,
         },
-        (1, 1, 32, 32),
-        (1, 3, 8, 8),
+        (1, 1, 16, 16),
+        (1, 3, 4, 4),
     ],
 ]
 
@@ -83,66 +83,75 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
     def test_prediction_shape(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
         if model_type == "AutoencoderKL":
-            autoencoder_model = AutoencoderKL(**autoencoder_params)
+            stage_1 = AutoencoderKL(**autoencoder_params)
         if model_type == "VQVAE":
-            autoencoder_model = VQVAE(**autoencoder_params)
+            stage_1 = VQVAE(**autoencoder_params)
         stage_2 = DiffusionModelUNet(**stage_2_params)
+
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        autoencoder_model.to(device)
+        stage_1.to(device)
         stage_2.to(device)
-        autoencoder_model.eval()
-        autoencoder_model.train()
+        stage_1.eval()
+        stage_2.eval()
+
         input = torch.randn(input_shape).to(device)
         noise = torch.randn(latent_shape).to(device)
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
         scheduler.set_timesteps(num_inference_steps=10)
+
         timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
         prediction = inferer(
-            inputs=input, autoencoder_model=autoencoder_model, diffusion_model=stage_2, noise=noise, timesteps=timesteps
+            inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
         )
         self.assertEqual(prediction.shape, latent_shape)
 
     @parameterized.expand(TEST_CASES)
     def test_sample_shape(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
         if model_type == "AutoencoderKL":
-            autoencoder_model = AutoencoderKL(**autoencoder_params)
+            stage_1 = AutoencoderKL(**autoencoder_params)
         if model_type == "VQVAE":
-            autoencoder_model = VQVAE(**autoencoder_params)
+            stage_1 = VQVAE(**autoencoder_params)
         stage_2 = DiffusionModelUNet(**stage_2_params)
+
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        autoencoder_model.to(device)
+        stage_1.to(device)
         stage_2.to(device)
-        autoencoder_model.eval()
-        autoencoder_model.train()
+        stage_1.eval()
+        stage_2.eval()
+
         noise = torch.randn(latent_shape).to(device)
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
         scheduler.set_timesteps(num_inference_steps=10)
+
         sample = inferer.sample(
-            input_noise=noise, autoencoder_model=autoencoder_model, diffusion_model=stage_2, scheduler=scheduler
+            input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler
         )
         self.assertEqual(sample.shape, input_shape)
 
     @parameterized.expand(TEST_CASES)
     def test_sample_intermediates(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
         if model_type == "AutoencoderKL":
-            autoencoder_model = AutoencoderKL(**autoencoder_params)
+            stage_1 = AutoencoderKL(**autoencoder_params)
         if model_type == "VQVAE":
-            autoencoder_model = VQVAE(**autoencoder_params)
+            stage_1 = VQVAE(**autoencoder_params)
         stage_2 = DiffusionModelUNet(**stage_2_params)
+
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        autoencoder_model.to(device)
+        stage_1.to(device)
         stage_2.to(device)
-        autoencoder_model.eval()
-        autoencoder_model.train()
+        stage_1.eval()
+        stage_2.eval()
+
         noise = torch.randn(latent_shape).to(device)
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
         scheduler.set_timesteps(num_inference_steps=10)
+
         sample, intermediates = inferer.sample(
             input_noise=noise,
-            autoencoder_model=autoencoder_model,
+            autoencoder_model=stage_1,
             diffusion_model=stage_2,
             scheduler=scheduler,
             save_intermediates=True,
@@ -154,22 +163,25 @@ def test_sample_intermediates(self, model_type, autoencoder_params, stage_2_para
     @parameterized.expand(TEST_CASES)
     def test_get_likelihoods(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
         if model_type == "AutoencoderKL":
-            autoencoder_model = AutoencoderKL(**autoencoder_params)
+            stage_1 = AutoencoderKL(**autoencoder_params)
         if model_type == "VQVAE":
-            autoencoder_model = VQVAE(**autoencoder_params)
+            stage_1 = VQVAE(**autoencoder_params)
         stage_2 = DiffusionModelUNet(**stage_2_params)
+
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        autoencoder_model.to(device)
+        stage_1.to(device)
         stage_2.to(device)
-        autoencoder_model.eval()
-        autoencoder_model.train()
+        stage_1.eval()
+        stage_2.eval()
+
         input = torch.randn(input_shape).to(device)
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
         scheduler.set_timesteps(num_inference_steps=10)
+
         sample, intermediates = inferer.get_likelihood(
             inputs=input,
-            autoencoder_model=autoencoder_model,
+            autoencoder_model=stage_1,
             diffusion_model=stage_2,
             scheduler=scheduler,
             save_intermediates=True,
@@ -180,22 +192,25 @@ def test_get_likelihoods(self, model_type, autoencoder_params, stage_2_params, i
     @parameterized.expand(TEST_CASES)
     def test_resample_likelihoods(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
         if model_type == "AutoencoderKL":
-            autoencoder_model = AutoencoderKL(**autoencoder_params)
+            stage_1 = AutoencoderKL(**autoencoder_params)
         if model_type == "VQVAE":
-            autoencoder_model = VQVAE(**autoencoder_params)
+            stage_1 = VQVAE(**autoencoder_params)
         stage_2 = DiffusionModelUNet(**stage_2_params)
+
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        autoencoder_model.to(device)
+        stage_1.to(device)
         stage_2.to(device)
-        autoencoder_model.eval()
-        autoencoder_model.train()
+        stage_1.eval()
+        stage_2.eval()
+
         input = torch.randn(input_shape).to(device)
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
         scheduler.set_timesteps(num_inference_steps=10)
+
         sample, intermediates = inferer.get_likelihood(
             inputs=input,
-            autoencoder_model=autoencoder_model,
+            autoencoder_model=stage_1,
            diffusion_model=stage_2,
            scheduler=scheduler,
            save_intermediates=True,
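
For orientation, below is a minimal standalone sketch of the sampling path these tests exercise, built from the first updated test case above (the shrunken AutoencoderKL and DiffusionModelUNet configurations). This is a sketch, not part of the patch: the import paths are an assumption based on MONAI Generative's package layout and do not appear in this diff.

```python
import torch

# Assumed import paths (MONAI Generative layout); adjust if the package differs.
from generative.inferers import LatentDiffusionInferer
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Stage 1: the small AutoencoderKL config from the updated TEST_CASES.
stage_1 = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(4, 4),
    latent_channels=3,
    attention_levels=[False, False],
    num_res_blocks=1,
    with_encoder_nonlocal_attn=False,
    with_decoder_nonlocal_attn=False,
    norm_num_groups=4,
).to(device)

# Stage 2: the diffusion UNet operating in the 3-channel latent space.
stage_2 = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_channels=[4, 4],
    norm_num_groups=4,
    attention_levels=[False, False],
    num_res_blocks=1,
    num_head_channels=4,
).to(device)

stage_1.eval()
stage_2.eval()

scheduler = DDPMScheduler(num_train_timesteps=10)
scheduler.set_timesteps(num_inference_steps=10)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)

# Sampling starts from noise of the latent shape (1, 3, 4, 4) and the
# inferer decodes the denoised latent back to image space via stage 1.
noise = torch.randn((1, 3, 4, 4)).to(device)
with torch.no_grad():
    sample = inferer.sample(
        input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler
    )
print(sample.shape)  # expected: torch.Size([1, 1, 8, 8]), the input_shape of the test case
```

The tiny channel counts and spatial sizes taken from the updated test cases keep this sketch, like the tests themselves, cheap enough to run on CPU.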