This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
30 changes: 26 additions & 4 deletions generative/inferers/inferer.py
@@ -19,10 +19,10 @@
import torch.nn.functional as F
from monai.inferers import Inferer
from monai.utils import optional_import
from monai.transforms import SpatialPad, CenterSpatialCrop

tqdm, has_tqdm = optional_import("tqdm", name="tqdm")


class DiffusionInferer(Inferer):
"""
DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass
@@ -293,7 +293,6 @@ def _get_decoder_log_likelihood(
        assert log_probs.shape == inputs.shape
        return log_probs


class LatentDiffusionInferer(DiffusionInferer):
"""
LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can
@@ -303,11 +302,24 @@ class LatentDiffusionInferer(DiffusionInferer):
        scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
        scale_factor: scale factor to multiply the values of the latent representation before processing it by the
            second stage.
        ldm_latent_shape: desired spatial latent space shape. Used if it differs from the autoencoder model's latent shape.
        autoencoder_latent_shape: autoencoder spatial latent space shape. Used if it differs from the diffusion model's latent shape.
    """

    def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0) -> None:
    def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0,
                 ldm_latent_shape: list | None = None,
                 autoencoder_latent_shape: list | None = None) -> None:
        super().__init__(scheduler=scheduler)
        self.scale_factor = scale_factor
        if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None):
            raise ValueError(
                "If ldm_latent_shape is None, autoencoder_latent_shape must be None, and vice versa."
            )
        self.ldm_latent_shape = ldm_latent_shape
        self.autoencoder_latent_shape = autoencoder_latent_shape
        if self.ldm_latent_shape is not None:
            self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape)
            self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape)
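A minimal sketch, not part of this PR, of how the two resizers just defined behave: MONAI transforms treat the first dimension of their input as channels, so for a batched latent of shape (B, C, H, W) the leading -1 lets the true channel dimension pass through untouched while the spatial dimensions are padded or cropped.

import torch
from monai.transforms import CenterSpatialCrop, SpatialPad

latent = torch.randn(1, 3, 6, 6)  # (B, C, H, W) latent from the first-stage model

ldm_resizer = SpatialPad(spatial_size=[-1, 8, 8])  # pad H, W up to 8 x 8
autoencoder_resizer = CenterSpatialCrop(roi_size=[-1, 6, 6])  # crop back to 6 x 6

padded = ldm_resizer(latent)
print(padded.shape)  # torch.Size([1, 3, 8, 8])
print(autoencoder_resizer(padded).shape)  # torch.Size([1, 3, 6, 6])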

    def __call__(
        self,
@@ -334,6 +346,9 @@ def __call__(
        with torch.no_grad():
            latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor

        if self.ldm_latent_shape is not None:
            latent = self.ldm_resizer(latent)

        prediction = super().__call__(
            inputs=latent,
            diffusion_model=diffusion_model,
@@ -386,6 +401,10 @@ def sample(
        else:
            latent = outputs

        if self.ldm_latent_shape is not None:
            latent = self.autoencoder_resizer(latent)
            latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates]

        image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)

        if save_intermediates:
@@ -437,6 +456,10 @@ def get_likelihood(
                f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
            )
        latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor

        if self.ldm_latent_shape is not None:
            latents = self.ldm_resizer(latents)

        outputs = super().get_likelihood(
            inputs=latents,
            diffusion_model=diffusion_model,
@@ -453,7 +476,6 @@ def get_likelihood(
            outputs = (outputs[0], intermediates)
        return outputs


class VQVAETransformerInferer(Inferer):
"""
Class to perform inference with a VQVAE + Transformer model.
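Taken together, the new arguments let the diffusion model operate at a latent resolution different from the autoencoder's: __call__ and get_likelihood pad encoded latents up to ldm_latent_shape, while sample crops generated latents back to autoencoder_latent_shape before decoding. Because SpatialPad only grows and CenterSpatialCrop only shrinks, ldm_latent_shape is implicitly assumed to be at least as large as autoencoder_latent_shape in every spatial dimension. A minimal construction sketch, with import paths assumed from this repository's layout and illustrative shapes:

from generative.inferers import LatentDiffusionInferer
from generative.networks.schedulers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = LatentDiffusionInferer(
    scheduler=scheduler,
    scale_factor=1.0,
    ldm_latent_shape=[8, 8],  # spatial shape the diffusion model runs at
    autoencoder_latent_shape=[6, 6],  # spatial shape the autoencoder produces
)
# Supplying only one of the two shapes raises ValueError (both or neither).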
117 changes: 117 additions & 0 deletions tests/test_latent_diffusion_inferer.py
@@ -103,6 +103,89 @@
        (1, 3, 4, 4, 4),
    ],
]
TEST_CASES_DIFF_SHAPES = [
    [
        "AutoencoderKL",
        {
            "spatial_dims": 2,
            "in_channels": 1,
            "out_channels": 1,
            "num_channels": (4, 4),
            "latent_channels": 3,
            "attention_levels": [False, False],
            "num_res_blocks": 1,
            "with_encoder_nonlocal_attn": False,
            "with_decoder_nonlocal_attn": False,
            "norm_num_groups": 4,
        },
        {
            "spatial_dims": 2,
            "in_channels": 3,
            "out_channels": 3,
            "num_channels": [4, 4],
            "norm_num_groups": 4,
            "attention_levels": [False, False],
            "num_res_blocks": 1,
            "num_head_channels": 4,
        },
        (1, 1, 12, 12),
        (1, 3, 8, 8),
    ],
    [
        "VQVAE",
        {
            "spatial_dims": 2,
            "in_channels": 1,
            "out_channels": 1,
            "num_channels": [4, 4],
            "num_res_layers": 1,
            "num_res_channels": [4, 4],
            "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
            "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
            "num_embeddings": 16,
            "embedding_dim": 3,
        },
        {
            "spatial_dims": 2,
            "in_channels": 3,
            "out_channels": 3,
            "num_channels": [8, 8],
            "norm_num_groups": 8,
            "attention_levels": [False, False],
            "num_res_blocks": 1,
            "num_head_channels": 8,
        },
        (1, 1, 12, 12),
        (1, 3, 8, 8),
    ],
    [
        "VQVAE",
        {
            "spatial_dims": 3,
            "in_channels": 1,
            "out_channels": 1,
            "num_channels": [4, 4],
            "num_res_layers": 1,
            "num_res_channels": [4, 4],
            "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
            "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
            "num_embeddings": 16,
            "embedding_dim": 3,
        },
        {
            "spatial_dims": 3,
            "in_channels": 3,
            "out_channels": 3,
            "num_channels": [8, 8],
            "norm_num_groups": 8,
            "attention_levels": [False, False],
            "num_res_blocks": 1,
            "num_head_channels": 8,
        },
        (1, 1, 12, 12, 12),
        (1, 3, 8, 8, 8),
    ],
]
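The new cases pair a 12 x 12 (x 12) input with an 8 x 8 (x 8) diffusion-model latent. The test method below derives the autoencoder's spatial latent shape from the number of channel levels before constructing the inferer; a quick sketch of that arithmetic, illustrative and mirroring the test's formula:

# One 2x downsampling per level beyond the first, as the test assumes.
input_shape = (1, 1, 12, 12)
num_levels = 2  # len(autoencoder_params["num_channels"]) in the cases above
autoencoder_latent_shape = [s // (2 ** (num_levels - 1)) for s in input_shape[2:]]
print(autoencoder_latent_shape)  # [6, 6] -- padded up to the (8, 8) DM latent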


class TestDiffusionSamplingInferer(unittest.TestCase):
@@ -325,6 +408,40 @@ def test_sample_shape_conditioned_concat(
        )
        self.assertEqual(sample.shape, input_shape)

    @parameterized.expand(TEST_CASES_DIFF_SHAPES)
    def test_sample_shape_different_latents(
        self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape
    ):
        if model_type == "AutoencoderKL":
            stage_1 = AutoencoderKL(**autoencoder_params)
        if model_type == "VQVAE":
            stage_1 = VQVAE(**autoencoder_params)
        stage_2 = DiffusionModelUNet(**stage_2_params)

        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        stage_1.to(device)
        stage_2.to(device)
        stage_1.eval()
        stage_2.eval()

        input = torch.randn(input_shape).to(device)
        noise = torch.randn(latent_shape).to(device)
        scheduler = DDPMScheduler(num_train_timesteps=10)
        # Infer the autoencoder's spatial latent shape from the input shape and the number of channel levels.
        autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["num_channels"]) - 1)) for i in input_shape[2:]]
        inferer = LatentDiffusionInferer(
            scheduler=scheduler,
            scale_factor=1.0,
            ldm_latent_shape=list(latent_shape[2:]),
            autoencoder_latent_shape=autoencoder_latent_shape,
        )
        scheduler.set_timesteps(num_inference_steps=10)

        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
        prediction = inferer(
            inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
        )
        self.assertEqual(prediction.shape, latent_shape)


if __name__ == "__main__":
    unittest.main()