diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d7516fa601e1..7dbbe35f3d55 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -15,6 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver import math +from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np @@ -274,11 +275,6 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc self.sigmas = torch.from_numpy(sigmas) - # when num_inference_steps == num_train_timesteps, we can end up with - # duplicates in timesteps. - _, unique_indices = np.unique(timesteps, return_index=True) - timesteps = timesteps[np.sort(unique_indices)] - self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -288,6 +284,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc ] * self.config.solver_order self.lower_order_nums = 0 + # add an index counter for schedulers that allow duplicated timesteps + self._index_counter = defaultdict(int) + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ @@ -660,11 +659,25 @@ def step( if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: + indices = (self.timesteps == timestep).nonzero() + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + + if len(indices) == 0: step_index = len(self.timesteps) - 1 else: - step_index = step_index.item() + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + if len(self._index_counter) == 0: + pos = 1 if len(indices) > 1 else 0 + else: + pos = self._index_counter[timestep_int] + step_index = indices[pos].item() + + # advance index counter by 1 + self._index_counter[timestep_int] += 1 + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] lower_order_final = ( (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index c9935780b983..86b24af24095 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -264,10 +264,10 @@ def test_fp16_support(self): assert sample.dtype == torch.float16 - def test_unique_timesteps(self, **config): + def test_duplicated_timesteps(self, **config): for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(scheduler.config.num_train_timesteps) - assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps + assert len(scheduler.timesteps) == scheduler.num_inference_steps