From 432c04dabaff76d7b5c38fcd300618c46e0f939e Mon Sep 17 00:00:00 2001 From: Joqsan Azocar Date: Wed, 4 Jan 2023 13:03:46 +0300 Subject: [PATCH] fix: DDPMScheduler.set_timesteps() --- src/diffusers/schedulers/scheduling_ddim.py | 8 ++++++++ src/diffusers/schedulers/scheduling_ddpm.py | 15 +++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 70cf22654873..95def885bceb 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -201,6 +201,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 86edcb441fcb..7c300d4a42c1 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -184,11 +184,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. 
""" - num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + self.num_inference_steps = num_inference_steps - timesteps = np.arange( - 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps - )[::-1].copy() + + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) def _get_variance(self, t, predicted_variance=None, variance_type=None):