diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 7be22017..b56b9832 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -18,6 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from monai.data import decollate_batch
 from monai.inferers import Inferer
 from monai.transforms import CenterSpatialCrop, SpatialPad
 from monai.utils import optional_import
@@ -348,8 +349,8 @@ def __init__(
         self.ldm_latent_shape = ldm_latent_shape
         self.autoencoder_latent_shape = autoencoder_latent_shape
         if self.ldm_latent_shape is not None:
-            self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape)
-            self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape)
+            self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
+            self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)
 
     def __call__(
         self,
@@ -379,7 +380,7 @@ def __call__(
         latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
 
         if self.ldm_latent_shape is not None:
-            latent = self.ldm_resizer(latent)
+            latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
 
         call = super().__call__
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
@@ -454,14 +455,15 @@ def sample(
         else:
             latent = outputs
 
-        if self.ldm_latent_shape is not None:
-            latent = self.autoencoder_resizer(latent)
-            latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates]
+        if self.autoencoder_latent_shape is not None:
+            latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
+            latent_intermediates = [
+                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+            ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
             decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
-
         image = decode(latent / self.scale_factor)
 
         if save_intermediates:
@@ -521,7 +523,7 @@ def get_likelihood(
         latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
 
         if self.ldm_latent_shape is not None:
-            latents = self.ldm_resizer(latents)
+            latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
 
         get_likelihood = super().get_likelihood
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
@@ -598,7 +600,7 @@ def __call__(
 
         diffuse = diffusion_model
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
-            diffuse = partial(diffusion_model, seg = seg)
+            diffuse = partial(diffusion_model, seg=seg)
 
         prediction = diffuse(
             x=noisy_image,
@@ -746,7 +748,7 @@ def get_likelihood(
 
         diffuse = diffusion_model
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
-            diffuse = partial(diffusion_model, seg = seg)
+            diffuse = partial(diffusion_model, seg=seg)
 
         if mode == "concat":
             noisy_image = torch.cat([noisy_image, conditioning], dim=1)
@@ -832,6 +834,7 @@ def get_likelihood(
         else:
             return total_kl
 
+
 class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
     """
     ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet,
@@ -861,7 +864,7 @@ def __init__(
         self.ldm_latent_shape = ldm_latent_shape
         self.autoencoder_latent_shape = autoencoder_latent_shape
         if self.ldm_latent_shape is not None:
-            self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape)
+            self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
             self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape)
 
     def __call__(
@@ -896,7 +899,8 @@ def __call__(
         latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
 
         if self.ldm_latent_shape is not None:
-            latent = self.ldm_resizer(latent)
+            latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
+
         if cn_cond.shape[2:] != latent.shape[2:]:
             cn_cond = F.interpolate(cn_cond, latent.shape[2:])
 
@@ -985,9 +989,11 @@ def sample(
         else:
             latent = outputs
 
-        if self.ldm_latent_shape is not None:
-            latent = self.autoencoder_resizer(latent)
-            latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates]
+        if self.autoencoder_latent_shape is not None:
+            latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
+            latent_intermediates = [
+                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+            ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
@@ -1060,7 +1066,7 @@ def get_likelihood(
             cn_cond = F.interpolate(cn_cond, latents.shape[2:])
 
         if self.ldm_latent_shape is not None:
-            latents = self.ldm_resizer(latents)
+            latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
 
         get_likelihood = super().get_likelihood
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
@@ -1085,6 +1091,7 @@ def get_likelihood(
             outputs = (outputs[0], intermediates)
         return outputs
 
+
 class VQVAETransformerInferer(Inferer):
     """
     Class to perform inference with a VQVAE + Transformer model.
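Note on the recurring `decollate_batch` / `torch.stack` pattern introduced above: MONAI's `SpatialPad` and `CenterSpatialCrop` transforms operate on channel-first, single-sample arrays rather than batched tensors, which appears to be why the diff drops the `[-1] + shape` workaround in favour of decollating the batch, resizing each sample, and re-stacking along the batch axis. A minimal sketch of the pattern, with hypothetical shapes chosen purely for illustration (a batch of two 3-channel 2D latents, padded from a 20x20 to a 24x24 latent grid):

```python
import torch
from monai.data import decollate_batch
from monai.transforms import CenterSpatialCrop, SpatialPad

latent = torch.randn(2, 3, 20, 20)  # (B, C, H, W) batch of latents

# SpatialPad / CenterSpatialCrop act on channel-first samples (C, H, W),
# so decollate the batch, resize per sample, then re-collate with stack.
ldm_resizer = SpatialPad(spatial_size=[24, 24])
padded = torch.stack([ldm_resizer(i) for i in decollate_batch(latent)], 0)
print(padded.shape)  # torch.Size([2, 3, 24, 24])

# The inverse direction crops sampled latents back to the autoencoder grid.
autoencoder_resizer = CenterSpatialCrop(roi_size=[20, 20])
cropped = torch.stack([autoencoder_resizer(i) for i in decollate_batch(padded)], 0)
print(cropped.shape)  # torch.Size([2, 3, 20, 20])
```

The same per-sample pattern is applied list-wise to `latent_intermediates` in the `sample` hunks, which is why the intermediate resizing gains the nested `torch.stack(...)` comprehension.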