diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index 366c8ee2..7989d13f 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -612,7 +612,9 @@ def __init__( num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) if len(num_res_blocks) != len(num_channels): - raise ValueError("`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`.") + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`." + ) self.encoder = Encoder( spatial_dims=spatial_dims, diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 38b532ae..54dfa277 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -1655,7 +1655,9 @@ def __init__( num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) if len(num_res_blocks) != len(num_channels): - raise ValueError("`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`.") + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`." + ) self.in_channels = in_channels self.block_out_channels = num_channels diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index 480e79e3..d1395113 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -103,7 +103,6 @@ def __init__( # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64)) self.clip_sample = clip_sample diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py index 2502c517..0a1f2018 100644 --- a/generative/networks/schedulers/pndm.py +++ b/generative/networks/schedulers/pndm.py @@ -117,11 +117,11 @@ def __init__( self.cur_sample = None self.ets = [] - self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() # default the number of inference timesteps to the number of train steps self.set_timesteps(num_train_timesteps) + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py index 0e8b49e1..ee0cda29 100644 --- a/tests/test_scheduler_pndm.py +++ b/tests/test_scheduler_pndm.py @@ -39,7 +39,6 @@ def test_add_noise_2d_shape(self, input_param, input_shape, expected_shape): noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) self.assertEqual(noisy.shape, expected_shape) - @parameterized.expand(TEST_CASES) def test_step_shape(self, input_param, input_shape, expected_shape): scheduler = PNDMScheduler(**input_param)