From 6cef2a03746e166c5e842d39d0168598d14eb219 Mon Sep 17 00:00:00 2001 From: Nihanth Subramanya Date: Fri, 3 Nov 2023 20:40:37 +0100 Subject: [PATCH 1/4] LMS scheduler: Allow forcing lower order steps for initial and final steps --- src/diffusers/schedulers/scheduling_lms_discrete.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 05126377763e..6308c45b159b 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -139,6 +139,7 @@ def __init__( prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, + use_nihanth_order_dropoff: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -168,6 +169,7 @@ def __init__( self.set_timesteps(num_train_timesteps, None) self.derivatives = [] self.is_scale_input_called = False + self.use_nihanth_order_dropoff = use_nihanth_order_dropoff self._step_index = None @@ -301,6 +303,7 @@ def _init_step_index(self, timestep): step_index = index_candidates[0] self._step_index = step_index.item() + self._initial_step_index = self._step_index # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): @@ -401,7 +404,11 @@ def step( self.derivatives.pop(0) # 3. Compute linear multistep coefficients - order = min(self.step_index + 1, order) + if self.use_nihanth_order_dropoff: + order = min(self.step_index - self._initial_step_index + 1, order) + order = min(len(self.timesteps) - self.step_index, order) + else: + order = min(self.step_index + 1, order) lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)] # 4. Compute previous sample based on the derivatives path From b9d1a47531793485a420dadf1d432c412eb71880 Mon Sep 17 00:00:00 2001 From: Nihanth Subramanya Date: Fri, 3 Nov 2023 21:41:00 +0100 Subject: [PATCH 2/4] Simplify --- src/diffusers/schedulers/scheduling_lms_discrete.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 6308c45b159b..3dc3ab7b802a 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -303,7 +303,6 @@ def _init_step_index(self, timestep): step_index = index_candidates[0] self._step_index = step_index.item() - self._initial_step_index = self._step_index # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): @@ -404,11 +403,10 @@ def step( self.derivatives.pop(0) # 3. Compute linear multistep coefficients + order = min(self.step_index + 1, order) if self.use_nihanth_order_dropoff: - order = min(self.step_index - self._initial_step_index + 1, order) - order = min(len(self.timesteps) - self.step_index, order) - else: - order = min(self.step_index + 1, order) + # Use order=1 for the last 15 steps + order = 1 if len(self.timesteps) - self.step_index < 15 else order lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)] # 4. Compute previous sample based on the derivatives path From 3b4a2fd460bdf7b58e2277c663dbf62012a39df8 Mon Sep 17 00:00:00 2001 From: Nihanth Subramanya Date: Tue, 7 Nov 2023 01:03:40 +0100 Subject: [PATCH 3/4] Fix final sigma value for dpm multistep --- .../schedulers/scheduling_dpmsolver_multistep.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 479a27de41ea..ad262aeeeeea 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -267,16 +267,18 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + sigma_last = [0.0] if self.euler_at_final else sigmas[-1:] + sigmas = np.concatenate([sigmas, sigma_last]).astype(np.float32) elif self.config.use_lu_lambdas: lambdas = np.flip(log_sigmas.copy()) lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) sigmas = np.exp(lambdas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + sigma_last = [0.0] if self.euler_at_final else sigmas[-1:] + sigmas = np.concatenate([sigmas, sigma_last]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigma_last = 0.0 if self.euler_at_final else ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) From a29d651ae54febb6c9164b6f02e41bccfb9e42fe Mon Sep 17 00:00:00 2001 From: Nihanth Subramanya Date: Wed, 22 Nov 2023 02:23:27 +0100 Subject: [PATCH 4/4] Impl denoising_start/end for SDXL ControlNet pipeline --- .../pipeline_controlnet_sd_xl_img2img.py | 80 +++++++++++++++++-- 1 file changed, 75 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 4fccd6a91b0f..d6191f05f5f8 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -782,13 +782,40 @@ def prepare_control_image( return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + else: + t_start = 0 - t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + if denoising_start is not None: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd devirative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + timesteps = timesteps[-num_inference_steps:] + return timesteps, num_inference_steps + return timesteps, num_inference_steps - t_start # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents @@ -979,6 +1006,14 @@ def do_classifier_free_guidance(self): def cross_attention_kwargs(self): return self._cross_attention_kwargs + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + @property def num_timesteps(self): return self._num_timesteps @@ -995,6 +1030,8 @@ def __call__( width: Optional[int] = None, strength: float = 0.8, num_inference_steps: int = 50, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, @@ -1236,6 +1273,8 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1322,11 +1361,20 @@ def __call__( assert False # 5. Prepare timesteps + def denoising_value_valid(dnv): + return isinstance(self.denoising_end, float) and 0 < dnv < 1 + self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid else None, + ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) self._num_timesteps = len(timesteps) + add_noise = True if self.denoising_start is None else False # 6. Prepare latent variables latents = self.prepare_latents( image, @@ -1336,7 +1384,7 @@ def __call__( prompt_embeds.dtype, device, generator, - True, + add_noise, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -1395,6 +1443,28 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance