diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index e12994b3..480e79e3 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -103,13 +103,15 @@ def __init__( # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 - # setable values - self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64)) self.clip_sample = clip_sample self.steps_offset = steps_offset + # default the number of inference timesteps to the number of train steps + self.set_timesteps(num_train_timesteps) + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py index 4f5e4f61..2502c517 100644 --- a/generative/networks/schedulers/pndm.py +++ b/generative/networks/schedulers/pndm.py @@ -117,13 +117,11 @@ def __init__( self.cur_sample = None self.ets = [] - # settable values - self.num_inference_steps = None + self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - self.prk_timesteps = torch.Tensor([]) - self.plms_timesteps = torch.Tensor([]) - self.timesteps = torch.Tensor([]) + # default the number of inference timesteps to the number of train steps + self.set_timesteps(num_train_timesteps) def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py index ff9b5ce0..0e8b49e1 100644 --- a/tests/test_scheduler_pndm.py +++ b/tests/test_scheduler_pndm.py @@ -39,13 +39,6 @@ def test_add_noise_2d_shape(self, input_param, input_shape, expected_shape): noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) self.assertEqual(noisy.shape, expected_shape) - @parameterized.expand(TEST_CASES) - def test_error_if_timesteps_not_set(self, input_param, input_shape, expected_shape): - scheduler = PNDMScheduler(**input_param) - with self.assertRaises(ValueError): - model_output = torch.randn(input_shape) - sample = torch.randn(input_shape) - scheduler.step(model_output=model_output, timestep=500, sample=sample) @parameterized.expand(TEST_CASES) def test_step_shape(self, input_param, input_shape, expected_shape):