From 119cf05bfa0dd1136f308f03897b25bacd13739b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 10 Nov 2023 03:06:48 +0000 Subject: [PATCH 1/3] fix --- .../schedulers/scheduling_dpmsolver_multistep.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 479a27de41ea..1a9bf070e62e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -768,13 +768,9 @@ def _init_step_index(self, timestep): if len(index_candidates) == 0: step_index = len(self.timesteps) - 1 # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() + # is always the first index. This way we can ensure we don't accidentally skip a sigma in + # case we start with a duplicated timestep + step_index = index_candidates[0].item() self._step_index = step_index @@ -885,7 +881,7 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [(schedule_timesteps == t).nonzero()[0].item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): From c2a8afc60c6596b0fcfee0458a74923d363d94bd Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 10 Nov 2023 04:05:44 +0000 Subject: [PATCH 2/3] Revert "fix" This reverts commit 119cf05bfa0dd1136f308f03897b25bacd13739b. 
--- .../schedulers/scheduling_dpmsolver_multistep.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 1a9bf070e62e..479a27de41ea 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -768,9 +768,13 @@ def _init_step_index(self, timestep): if len(index_candidates) == 0: step_index = len(self.timesteps) - 1 # The sigma index that is taken for the **very** first `step` - # is always the first index. This way we can ensure we don't accidentally skip a sigma in - # case we start with a duplicated timestep - step_index = index_candidates[0].item() + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() self._step_index = step_index @@ -881,7 +885,7 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero()[0].item() for t in timesteps] + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): From df493b90ac0326a7e8d6989e9d1a2ea875768d0f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 10 Nov 2023 07:18:32 +0000 Subject: [PATCH 3/3] draft --- .../pipeline_stable_diffusion_img2img.py | 1 + .../scheduling_dpmsolver_multistep.py | 51 ++++++++++++------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py 
b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 40daecfa913f..685fea3b47e4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -551,6 +551,7 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + self.scheduler._step_index_init = t_start * self.scheduler.order return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 479a27de41ea..2fa34c59c5bc 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -215,6 +215,7 @@ def __init__( self.model_outputs = [None] * solver_order self.lower_order_nums = 0 self._step_index = None + self._step_index_init = None @property def step_index(self): @@ -222,6 +223,13 @@ def step_index(self): The index counter for current timestep. It will increae 1 after each scheduler step. """ return self._step_index + + @property + def step_index_init(self): + """ + The first step index of the denoising loop.
+ """ + return self._step_index_init def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ @@ -760,23 +768,28 @@ def multistep_dpm_solver_third_order_update( return x_t def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - - index_candidates = (self.timesteps == timestep).nonzero() - - if len(index_candidates) == 0: - step_index = len(self.timesteps) - 1 - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() + + if self.step_index_init is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() - self._step_index = step_index + self._step_index_init = step_index + self._step_index = step_index + else: + self._step_index = self.step_index_init def step( self, @@ -884,8 +897,10 @@ def add_noise( else: schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + if self.step_index_init is None: + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + else: + step_indices = [self.step_index_init] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape):