diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index b56b9832..5880d456 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -457,9 +457,10 @@ def sample(
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
@@ -991,9 +992,10 @@ def sample(
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
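
For reference, a minimal standalone sketch of the guarded logic this patch introduces, outside the inferer itself. The tensor shapes, the dummy latents, and the `SpatialPad` stand-in for `self.autoencoder_resizer` are assumptions for illustration only. Without the `if save_intermediates:` guard, the list comprehension would reference `latent_intermediates`, which is never assigned when intermediates are not requested, and `sample` would fail with an `UnboundLocalError`.

```python
import torch
from monai.data import decollate_batch
from monai.transforms import SpatialPad  # hypothetical stand-in for the inferer's autoencoder_resizer

save_intermediates = False  # reproduces the case the patch guards against

# Dummy results of the diffusion sampling loop (shapes are illustrative only).
latent = torch.randn(2, 3, 16, 16)
if save_intermediates:
    latent_intermediates = [torch.randn(2, 3, 16, 16) for _ in range(4)]

resizer = SpatialPad(spatial_size=(24, 24))  # assumed resizer for the sketch

# Patched logic: the final latent is always resized, but the intermediate
# latents are only touched when they were actually collected.
latent = torch.stack([resizer(i) for i in decollate_batch(latent)], 0)
if save_intermediates:
    latent_intermediates = [
        torch.stack([resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
    ]

print(latent.shape)  # torch.Size([2, 3, 24, 24])
```

The guard mirrors how the sampling loop assigns its outputs: `latent_intermediates` only exists when `save_intermediates=True`, so the resizing step must be conditioned the same way.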