Merged: tests/test_latent_diffusion_inferer.py (105 changes: 60 additions & 45 deletions)
@@ -27,26 +27,26 @@
             "spatial_dims": 2,
             "in_channels": 1,
             "out_channels": 1,
-            "num_channels": (8, 8, 8),
+            "num_channels": (4, 4),
             "latent_channels": 3,
-            "attention_levels": [False, False, False],
+            "attention_levels": [False, False],
             "num_res_blocks": 1,
             "with_encoder_nonlocal_attn": False,
             "with_decoder_nonlocal_attn": False,
-            "norm_num_groups": 8,
+            "norm_num_groups": 4,
         },
         {
             "spatial_dims": 2,
             "in_channels": 3,
             "out_channels": 3,
-            "num_channels": [8, 8, 8],
-            "norm_num_groups": 8,
-            "attention_levels": [False, False, True],
+            "num_channels": [4, 4],
+            "norm_num_groups": 4,
+            "attention_levels": [False, False],
             "num_res_blocks": 1,
-            "num_head_channels": 8,
+            "num_head_channels": 4,
         },
-        (1, 1, 32, 32),
-        (1, 3, 8, 8),
+        (1, 1, 8, 8),
+        (1, 3, 4, 4),
     ],
     [
         "VQVAE",
@@ -58,23 +58,23 @@
             "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
             "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
             "num_res_layers": 1,
-            "num_channels": [8, 8],
-            "num_res_channels": [8, 8],
+            "num_channels": [4, 4],
+            "num_res_channels": [4, 4],
             "num_embeddings": 16,
             "embedding_dim": 3,
         },
         {
             "spatial_dims": 2,
             "in_channels": 3,
             "out_channels": 3,
-            "num_channels": [8, 8, 8],
+            "num_channels": [8, 8],
             "norm_num_groups": 8,
-            "attention_levels": [False, False, True],
+            "attention_levels": [False, False],
             "num_res_blocks": 1,
             "num_head_channels": 8,
         },
-        (1, 1, 32, 32),
-        (1, 3, 8, 8),
+        (1, 1, 16, 16),
+        (1, 3, 4, 4),
     ],
 ]
 
@@ -83,66 +83,75 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
     def test_prediction_shape(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
         if model_type == "AutoencoderKL":
-            autoencoder_model = AutoencoderKL(**autoencoder_params)
+            stage_1 = AutoencoderKL(**autoencoder_params)
         if model_type == "VQVAE":
-            autoencoder_model = VQVAE(**autoencoder_params)
+            stage_1 = VQVAE(**autoencoder_params)
         stage_2 = DiffusionModelUNet(**stage_2_params)
 
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        autoencoder_model.to(device)
+        stage_1.to(device)
         stage_2.to(device)
-        autoencoder_model.eval()
-        autoencoder_model.train()
+        stage_1.eval()
+        stage_2.eval()
+
         input = torch.randn(input_shape).to(device)
         noise = torch.randn(latent_shape).to(device)
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+
         timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
         prediction = inferer(
-            inputs=input, autoencoder_model=autoencoder_model, diffusion_model=stage_2, noise=noise, timesteps=timesteps
+            inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
         )
         self.assertEqual(prediction.shape, latent_shape)
 
     @parameterized.expand(TEST_CASES)
     def test_sample_shape(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
         if model_type == "AutoencoderKL":
-            autoencoder_model = AutoencoderKL(**autoencoder_params)
+            stage_1 = AutoencoderKL(**autoencoder_params)
         if model_type == "VQVAE":
-            autoencoder_model = VQVAE(**autoencoder_params)
+            stage_1 = VQVAE(**autoencoder_params)
         stage_2 = DiffusionModelUNet(**stage_2_params)
 
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        autoencoder_model.to(device)
+        stage_1.to(device)
         stage_2.to(device)
-        autoencoder_model.eval()
-        autoencoder_model.train()
+        stage_1.eval()
+        stage_2.eval()
+
         noise = torch.randn(latent_shape).to(device)
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+
         sample = inferer.sample(
-            input_noise=noise, autoencoder_model=autoencoder_model, diffusion_model=stage_2, scheduler=scheduler
+            input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler
        )
         self.assertEqual(sample.shape, input_shape)
 
     @parameterized.expand(TEST_CASES)
     def test_sample_intermediates(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
         if model_type == "AutoencoderKL":
-            autoencoder_model = AutoencoderKL(**autoencoder_params)
+            stage_1 = AutoencoderKL(**autoencoder_params)
         if model_type == "VQVAE":
-            autoencoder_model = VQVAE(**autoencoder_params)
+            stage_1 = VQVAE(**autoencoder_params)
         stage_2 = DiffusionModelUNet(**stage_2_params)
 
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        autoencoder_model.to(device)
+        stage_1.to(device)
         stage_2.to(device)
-        autoencoder_model.eval()
-        autoencoder_model.train()
+        stage_1.eval()
+        stage_2.eval()
+
         noise = torch.randn(latent_shape).to(device)
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+
         sample, intermediates = inferer.sample(
             input_noise=noise,
-            autoencoder_model=autoencoder_model,
+            autoencoder_model=stage_1,
             diffusion_model=stage_2,
             scheduler=scheduler,
             save_intermediates=True,
@@ -154,22 +163,25 @@ def test_sample_intermediates(self, model_type, autoencoder_params, stage_2_para
     @parameterized.expand(TEST_CASES)
     def test_get_likelihoods(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
         if model_type == "AutoencoderKL":
-            autoencoder_model = AutoencoderKL(**autoencoder_params)
+            stage_1 = AutoencoderKL(**autoencoder_params)
         if model_type == "VQVAE":
-            autoencoder_model = VQVAE(**autoencoder_params)
+            stage_1 = VQVAE(**autoencoder_params)
         stage_2 = DiffusionModelUNet(**stage_2_params)
 
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        autoencoder_model.to(device)
+        stage_1.to(device)
         stage_2.to(device)
-        autoencoder_model.eval()
-        autoencoder_model.train()
+        stage_1.eval()
+        stage_2.eval()
+
         input = torch.randn(input_shape).to(device)
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+
         sample, intermediates = inferer.get_likelihood(
             inputs=input,
-            autoencoder_model=autoencoder_model,
+            autoencoder_model=stage_1,
             diffusion_model=stage_2,
             scheduler=scheduler,
             save_intermediates=True,
@@ -180,22 +192,25 @@ def test_get_likelihoods(self, model_type, autoencoder_params, stage_2_params, i
     @parameterized.expand(TEST_CASES)
     def test_resample_likelihoods(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
         if model_type == "AutoencoderKL":
-            autoencoder_model = AutoencoderKL(**autoencoder_params)
+            stage_1 = AutoencoderKL(**autoencoder_params)
         if model_type == "VQVAE":
-            autoencoder_model = VQVAE(**autoencoder_params)
+            stage_1 = VQVAE(**autoencoder_params)
         stage_2 = DiffusionModelUNet(**stage_2_params)
 
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        autoencoder_model.to(device)
+        stage_1.to(device)
         stage_2.to(device)
-        autoencoder_model.eval()
-        autoencoder_model.train()
+        stage_1.eval()
+        stage_2.eval()
+
         input = torch.randn(input_shape).to(device)
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+
         sample, intermediates = inferer.get_likelihood(
             inputs=input,
-            autoencoder_model=autoencoder_model,
+            autoencoder_model=stage_1,
             diffusion_model=stage_2,
             scheduler=scheduler,
             save_intermediates=True,
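For reference, the calls these tests exercise compose into the following standalone sketch of the two-stage pipeline. It is built only from the constructors, parameters, and inferer calls visible in the diff above (mirroring the first AutoencoderKL test case); the import paths are assumptions, since the import block sits outside this diff.

# Minimal sketch of the latent diffusion pipeline exercised by these tests.
# NOTE: the import paths below are assumptions; the diff does not show the import block.
import torch
from generative.inferers import LatentDiffusionInferer
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Stage 1: autoencoder mapping 8x8 single-channel images to a 3-channel 4x4 latent
# (parameters taken from the first test case above).
stage_1 = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(4, 4),
    latent_channels=3,
    attention_levels=[False, False],
    num_res_blocks=1,
    with_encoder_nonlocal_attn=False,
    with_decoder_nonlocal_attn=False,
    norm_num_groups=4,
).to(device)
stage_1.eval()

# Stage 2: diffusion UNet that denoises in the latent space.
stage_2 = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_channels=[4, 4],
    norm_num_groups=4,
    attention_levels=[False, False],
    num_res_blocks=1,
    num_head_channels=4,
).to(device)
stage_2.eval()

scheduler = DDPMScheduler(num_train_timesteps=10)
scheduler.set_timesteps(num_inference_steps=10)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)

# Training-style forward pass: the inferer encodes the inputs to latents, adds
# noise at the sampled timesteps, and returns the UNet's noise prediction, so
# the output shape matches the latent shape (1, 3, 4, 4), as the tests assert.
input = torch.randn(1, 1, 8, 8).to(device)
noise = torch.randn(1, 3, 4, 4).to(device)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (1,), device=device).long()
prediction = inferer(
    inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
)

# Sampling: denoise latent noise step by step, then decode with stage 1, so the
# sample shape matches the image shape (1, 1, 8, 8), as the tests assert.
sample = inferer.sample(
    input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler
)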