diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index e9b04e9ca1cc..ffe3ec64f9ae 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -171,6 +171,7 @@ def __init__( self.model_outputs = [None] * solver_order self.lower_order_nums = 0 + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_timesteps def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -181,14 +182,22 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - self.num_inference_steps = num_inference_steps timesteps = ( np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + self.model_outputs = [ None, ] * self.config.solver_order diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 28f0da2c41fb..c41fc7e16a4f 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -192,14 +192,22 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - self.num_inference_steps = num_inference_steps timesteps = ( np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + self.model_outputs = [ None, ] * self.config.solver_order diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 0d164088105c..07e8b152b9d3 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -194,21 +194,29 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - self.num_inference_steps = num_inference_steps timesteps = ( np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + self.model_outputs = [ None, ] * self.config.solver_order self.lower_order_nums = 0 self.last_sample = None if self.solver_p: - self.solver_p.set_timesteps(num_inference_steps, device=device) + self.solver_p.set_timesteps(self.num_inference_steps, device=device) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index 9da43714f570..a5a1d09c6b65 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -243,3 +243,11 @@ def test_fp16_support(self): sample = scheduler.step(residual, t, sample).prev_sample assert sample.dtype == torch.float16 + + def test_unique_timesteps(self, **config): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 6154c8e2d625..62cffc67388c 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -229,3 +229,11 @@ def test_fp16_support(self): sample = scheduler.step(residual, t, sample).prev_sample assert sample.dtype == torch.float16 + + def test_unique_timesteps(self, **config): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps