From d85a3090f5e9ba2d18e2d0fe0efb37f9232a4e79 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Mon, 18 Dec 2023 11:17:31 +0000 Subject: [PATCH 1/2] The issue came from the definition of SpatialPad (self.ldm_resizer) and CenterSpatialCrop (self.autoencoder_resizer). The spatial_size passed to them included a leading [-1] intended to cover the channel dimension. As written, the code assumed that this channel dimension was a spatial dimension and that the batch dimension was the channel one, which led to errors caused by the affine transform of the MetaTensor being wrong. The resizers should operate on an unbatched version of the tensor, so we replaced the direct calls to the resizers with ones that decollate the batch, resize each element, and then stack the elements of the batch together again. --- generative/inferers/inferer.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 7be22017..44248d06 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -14,7 +14,7 @@ import math from collections.abc import Callable, Sequence from functools import partial - +from monai.data import decollate_batch import torch import torch.nn as nn import torch.nn.functional as F @@ -348,8 +348,8 @@ def __init__( self.ldm_latent_shape = ldm_latent_shape self.autoencoder_latent_shape = autoencoder_latent_shape if self.ldm_latent_shape is not None: - self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape) - self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape) + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) def __call__( self, @@ -379,7 +379,7 @@ def __call__( latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor if self.ldm_latent_shape is not None: - latent = self.ldm_resizer(latent) + latent =
torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)],0) call = super().__call__ if isinstance(diffusion_model, SPADEDiffusionModelUNet): @@ -454,14 +454,14 @@ def sample( else: latent = outputs - if self.ldm_latent_shape is not None: - latent = self.autoencoder_resizer(latent) - latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates] + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) + for l in latent_intermediates] decode = autoencoder_model.decode_stage_2_outputs if isinstance(autoencoder_model, SPADEAutoencoderKL): decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) - image = decode(latent / self.scale_factor) if save_intermediates: @@ -521,7 +521,7 @@ def get_likelihood( latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor if self.ldm_latent_shape is not None: - latents = self.ldm_resizer(latents) + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) get_likelihood = super().get_likelihood if isinstance(diffusion_model, SPADEDiffusionModelUNet): @@ -861,7 +861,7 @@ def __init__( self.ldm_latent_shape = ldm_latent_shape self.autoencoder_latent_shape = autoencoder_latent_shape if self.ldm_latent_shape is not None: - self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape) + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape) def __call__( @@ -896,7 +896,8 @@ def __call__( latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor if self.ldm_latent_shape is not None: - latent = self.ldm_resizer(latent) + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) + if cn_cond.shape[2:] != latent.shape[2:]: 
cn_cond = F.interpolate(cn_cond, latent.shape[2:]) @@ -985,9 +986,10 @@ def sample( else: latent = outputs - if self.ldm_latent_shape is not None: - latent = self.autoencoder_resizer(latent) - latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates] + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) + for l in latent_intermediates] decode = autoencoder_model.decode_stage_2_outputs if isinstance(autoencoder_model, SPADEAutoencoderKL): @@ -1060,7 +1062,7 @@ def get_likelihood( cn_cond = F.interpolate(cn_cond, latents.shape[2:]) if self.ldm_latent_shape is not None: - latents = self.ldm_resizer(latents) + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) get_likelihood = super().get_likelihood if isinstance(diffusion_model, SPADEDiffusionModelUNet): From 1f4afa4f23c03669717fa41be9b24586901c3a30 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 20 Dec 2023 14:01:56 +0000 Subject: [PATCH 2/2] Formatting --- generative/inferers/inferer.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 44248d06..b56b9832 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -14,10 +14,11 @@ import math from collections.abc import Callable, Sequence from functools import partial -from monai.data import decollate_batch + import torch import torch.nn as nn import torch.nn.functional as F +from monai.data import decollate_batch from monai.inferers import Inferer from monai.transforms import CenterSpatialCrop, SpatialPad from monai.utils import optional_import @@ -379,7 +380,7 @@ def __call__( latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor if self.ldm_latent_shape is not None: - 
latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)],0) + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) call = super().__call__ if isinstance(diffusion_model, SPADEDiffusionModelUNet): @@ -456,8 +457,9 @@ def sample( if self.autoencoder_latent_shape is not None: latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) - latent_intermediates = [torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) - for l in latent_intermediates] + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] decode = autoencoder_model.decode_stage_2_outputs if isinstance(autoencoder_model, SPADEAutoencoderKL): @@ -598,7 +600,7 @@ def __call__( diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): - diffuse = partial(diffusion_model, seg = seg) + diffuse = partial(diffusion_model, seg=seg) prediction = diffuse( x=noisy_image, @@ -746,7 +748,7 @@ def get_likelihood( diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): - diffuse = partial(diffusion_model, seg = seg) + diffuse = partial(diffusion_model, seg=seg) if mode == "concat": noisy_image = torch.cat([noisy_image, conditioning], dim=1) @@ -832,6 +834,7 @@ def get_likelihood( else: return total_kl + class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer): """ ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet, @@ -988,8 +991,9 @@ def sample( if self.autoencoder_latent_shape is not None: latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) - latent_intermediates = [torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) - for l in latent_intermediates] + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in 
latent_intermediates + ] decode = autoencoder_model.decode_stage_2_outputs if isinstance(autoencoder_model, SPADEAutoencoderKL): @@ -1087,6 +1091,7 @@ def get_likelihood( outputs = (outputs[0], intermediates) return outputs + class VQVAETransformerInferer(Inferer): """ Class to perform inference with a VQVAE + Transformer model.