From 67bc916fabaab3152a2292f551b01036c49f1cec Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Sat, 3 Feb 2024 11:46:25 +0000 Subject: [PATCH] Fix bug in inferer sample methods when a crop is needed before passing the output of the DM to the AEKL, and save_intermediates is False. --- generative/inferers/inferer.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index b56b9832..5880d456 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -457,9 +457,10 @@ def sample( if self.autoencoder_latent_shape is not None: latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) - latent_intermediates = [ - torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates - ] + if save_intermediates: + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] decode = autoencoder_model.decode_stage_2_outputs if isinstance(autoencoder_model, SPADEAutoencoderKL): @@ -991,9 +992,10 @@ def sample( if self.autoencoder_latent_shape is not None: latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) - latent_intermediates = [ - torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates - ] + if save_intermediates: + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] decode = autoencoder_model.decode_stage_2_outputs if isinstance(autoencoder_model, SPADEAutoencoderKL):