From 3993cf30f2c1b0986425dd6251b3cc4d53fda203 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sat, 11 Feb 2023 10:47:18 +0000
Subject: [PATCH] Reformat code and use decorators

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/inferers/inferer.py                | 17 +++++++----------
 generative/networks/nets/autoencoderkl.py     |  5 ++++-
 .../networks/nets/diffusion_model_unet.py     |  5 ++++-
 generative/networks/schedulers/ddim.py        |  1 -
 generative/networks/schedulers/pndm.py        |  2 +-
 tests/test_scheduler_pndm.py                  |  1 -
 6 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 84ca866e..3b4f6547 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -59,6 +59,7 @@ def __call__(
 
         return prediction
 
+    @torch.no_grad()
     def sample(
         self,
         input_noise: torch.Tensor,
@@ -89,10 +90,9 @@ def sample(
         intermediates = []
         for t in progress_bar:
             # 1. predict noise model_output
-            with torch.no_grad():
-                model_output = diffusion_model(
-                    image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
-                )
+            model_output = diffusion_model(
+                image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
+            )
 
             # 2. compute previous image: x_t -> x_t-1
             image, _ = scheduler.step(model_output, t, image)
@@ -310,6 +310,7 @@ def __call__(
 
         return prediction
 
+    @torch.no_grad()
     def sample(
         self,
         input_noise: torch.Tensor,
@@ -347,16 +348,12 @@ def sample(
         else:
             latent = outputs
 
-        with torch.no_grad():
-            image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)
+        image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)
 
         if save_intermediates:
             intermediates = []
             for latent_intermediate in latent_intermediates:
-                with torch.no_grad():
-                    intermediates.append(
-                        autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor)
-                    )
+                intermediates.append(autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor))
 
             return image, intermediates
         else:
diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py
index 366c8ee2..23750839 100644
--- a/generative/networks/nets/autoencoderkl.py
+++ b/generative/networks/nets/autoencoderkl.py
@@ -612,7 +612,10 @@ def __init__(
         num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))
 
         if len(num_res_blocks) != len(num_channels):
-            raise ValueError("`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`.")
+            raise ValueError(
+                "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+                "`num_channels`."
+            )
 
         self.encoder = Encoder(
             spatial_dims=spatial_dims,
diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index 38b532ae..b651f206 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -1655,7 +1655,10 @@ def __init__(
         num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))
 
         if len(num_res_blocks) != len(num_channels):
-            raise ValueError("`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`.")
+            raise ValueError(
+                "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+                "`num_channels`."
+            )
 
         self.in_channels = in_channels
         self.block_out_channels = num_channels
diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py
index 480e79e3..d1395113 100644
--- a/generative/networks/schedulers/ddim.py
+++ b/generative/networks/schedulers/ddim.py
@@ -103,7 +103,6 @@ def __init__(
 
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0
 
-        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
         self.clip_sample = clip_sample
diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py
index 2502c517..0a1f2018 100644
--- a/generative/networks/schedulers/pndm.py
+++ b/generative/networks/schedulers/pndm.py
@@ -117,11 +117,11 @@ def __init__(
         self.cur_sample = None
         self.ets = []
 
-        self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
 
         # default the number of inference timesteps to the number of train steps
         self.set_timesteps(num_train_timesteps)
+
     def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py
index 0e8b49e1..ee0cda29 100644
--- a/tests/test_scheduler_pndm.py
+++ b/tests/test_scheduler_pndm.py
@@ -39,7 +39,6 @@ def test_add_noise_2d_shape(self, input_param, input_shape, expected_shape):
         noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
         self.assertEqual(noisy.shape, expected_shape)
 
-    @parameterized.expand(TEST_CASES)
     def test_step_shape(self, input_param, input_shape, expected_shape):
         scheduler = PNDMScheduler(**input_param)
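
Note on the decorator change: torch.no_grad() works both as a context manager and as a function
decorator, so decorating sample() disables gradient tracking for the entire method body, matching the
behaviour of the inner "with torch.no_grad():" blocks this patch removes. The snippet below is only an
illustrative sketch of that equivalence; TinyModel is a made-up stand-in, not code from this repository.

    import torch
    from torch import nn

    class TinyModel(nn.Module):
        # Hypothetical stand-in for a diffusion model; not part of generative/.
        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(4, 4)

        @torch.no_grad()  # decorator form: the whole method runs with gradient tracking disabled
        def sample_decorated(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)

        def sample_context(self, x: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():  # context-manager form: same effect, applied per block
                return self.linear(x)

    model = TinyModel()
    x = torch.randn(2, 4, requires_grad=True)
    assert not model.sample_decorated(x).requires_grad
    assert not model.sample_context(x).requires_grad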