diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 84ca866e..3b4f6547 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -59,6 +59,7 @@ def __call__( return prediction + @torch.no_grad() def sample( self, input_noise: torch.Tensor, @@ -89,10 +90,9 @@ def sample( intermediates = [] for t in progress_bar: # 1. predict noise model_output - with torch.no_grad(): - model_output = diffusion_model( - image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning - ) + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning + ) # 2. compute previous image: x_t -> x_t-1 image, _ = scheduler.step(model_output, t, image) @@ -310,6 +310,7 @@ def __call__( return prediction + @torch.no_grad() def sample( self, input_noise: torch.Tensor, @@ -347,16 +348,12 @@ def sample( else: latent = outputs - with torch.no_grad(): - image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor) + image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor) if save_intermediates: intermediates = [] for latent_intermediate in latent_intermediates: - with torch.no_grad(): - intermediates.append( - autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor) - ) + intermediates.append(autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor)) return image, intermediates else: