From 9db1b4599956c8479d0e08261bb49a33312dc014 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 7 Jan 2024 08:40:01 +0000 Subject: [PATCH 01/25] fix --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 49c07a504985..a42cd927cd90 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -267,7 +267,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) elif self.config.use_lu_lambdas: lambdas = np.flip(log_sigmas.copy()) lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) @@ -831,7 +831,9 @@ def step( # Improve numerical stability for small number of steps lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( - self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.use_karras_sigmas ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 From 72e58ba9fbe320b52199bbcb61a0fb0185d24807 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 7 Jan 2024 08:47:53 +0000 Subject: [PATCH 02/25] fix copies --- .../schedulers/scheduling_dpmsolver_multistep_inverse.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 5d8f3fdf49cd..fb3b488891df 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -824,7 +824,9 @@ def step( # Improve numerical stability for small number of steps lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( - self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.use_karras_sigmas ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 From 9a494792c263eadadf70118b96c2f0a539256448 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 11 Jan 2024 07:48:00 +0000 Subject: [PATCH 03/25] refactor --- .../scheduling_dpmsolver_multistep.py | 31 ++++++++++++++++--- .../scheduling_dpmsolver_multistep_inverse.py | 2 +- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index a42cd927cd90..d53d2c43c0ad 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -164,6 +164,7 @@ def __init__( lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "default", # "denoise_to_zero", "default" use_lu_lambdas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, @@ -267,16 +268,38 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) + if self.config.final_sigmas_type == "default": + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + elif self.config.final_sigmas_type == "denoise_to_zero": + sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) + else: + raise ValueError( + f"final_sigmas_type must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + ) elif self.config.use_lu_lambdas: lambdas = np.flip(log_sigmas.copy()) lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) sigmas = np.exp(lambdas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + if self.config.final_sigmas_type == "default": + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + elif self.config.final_sigmas_type == "denoise_to_zero": + sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) + else: + raise ValueError( + f"final_sigmas_type must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + ) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + if self.config.final_sigmas_type == "default": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "denoise_to_zero": + sigma_last = 0 + else: + raise ValueError( + f"final_sigmas_type must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) @@ -833,7 +856,7 @@ def step( lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) - or self.config.use_karras_sigmas + or self.config.final_sigmas_type == "denoise_to_zero" ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index fb3b488891df..80d3043afd80 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -826,7 +826,7 @@ def step( lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) - or self.config.use_karras_sigmas + or self.config.final_sigmas_type == "denoise_to_zero" ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 From 6fbdd12edc72cd44c86f19559ee744246f5d34ee Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 15 Jan 2024 19:42:23 +0000 Subject: [PATCH 04/25] fix dpm inverse --- .../scheduling_dpmsolver_multistep_inverse.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 80d3043afd80..6f56b73885f9 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -160,6 +160,7 @@ def __init__( lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "default", # "denoise_to_zero", "default" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", @@ -264,13 +265,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + if self.config.final_sigmas_type == "default": + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + elif self.config.final_sigmas_type == "denoise_to_zero": + sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) + else: + raise ValueError( + f"final_sigmas_type must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + ) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_max = ( - (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] - ) ** 0.5 - sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32) + if self.config.final_sigmas_type == "default": + sigma_last = ( + (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] + ) ** 0.5 + elif self.config.final_sigmas_type == "denoise_to_zero": + sigma_last = 0 + else: + raise ValueError( + f"final_sigmas_type must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) From 1b2a42871510fa0c240486c06b8bf54b6d293725 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 15 Jan 2024 21:10:11 +0000 Subject: [PATCH 05/25] dpm singlestep --- .../scheduling_dpmsolver_multistep.py | 6 ++--- .../scheduling_dpmsolver_multistep_inverse.py | 4 +-- .../scheduling_dpmsolver_singlestep.py | 25 +++++++++++++++++-- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d53d2c43c0ad..724aeecfeb59 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -274,7 +274,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) else: raise ValueError( - f"final_sigmas_type must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" ) elif self.config.use_lu_lambdas: lambdas = np.flip(log_sigmas.copy()) @@ -287,7 +287,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) else: raise ValueError( - f"final_sigmas_type must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" ) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) @@ -297,7 +297,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigma_last = 0 else: raise ValueError( - f"final_sigmas_type must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 6f56b73885f9..c2c7c684d4da 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -271,7 +271,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) else: raise ValueError( - f"final_sigmas_type must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" ) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) @@ -283,7 +283,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigma_last = 0 else: raise ValueError( - f"final_sigmas_type must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index dea033822e14..6f6a4f705bed 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -150,6 +150,7 @@ def __init__( solver_type: str = "midpoint", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "default", # "denoise_to_zero", "default" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): @@ -267,10 +268,24 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + if self.config.final_sigmas_type == "default": + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + elif self.config.final_sigmas_type == "denoise_to_zero": + sigmas = np.concatenate([sigmas, np.array([0.0])]).astype(np.float32) + else: + raise ValueError( + f" `final_sigmas_type` must be one of `default` or `denoise_to_zero`, but got {self.config.final_sigmas_type}" + ) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + if self.config.final_sigmas_type == "default": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "denoise_to_zero": + sigma_last = 0 + else: + raise ValueError( + f" `final_sigmas_type` must be one of `default` or `denoise_to_zero`, but got {self.config.final_sigmas_type}" + ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) @@ -285,6 +300,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) self.register_to_config(lower_order_final=True) + if not self.config.lower_order_final and self.config.final_sigmas_type == "denoise_to_zero": + logger.warn( + "denoise_to_zero is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True." + ) + self.register_to_config(lower_order_final=True) + self.order_list = self.get_order_list(num_inference_steps) # add an index counter for schedulers that allow duplicated timesteps From c6ec05ff94193d3809b82ce3c63be58bc3780a8f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 15 Jan 2024 22:01:23 +0000 Subject: [PATCH 06/25] add valueerror for algorithems that does not support denoise_to_zero --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 5 +++++ .../schedulers/scheduling_dpmsolver_multistep_inverse.py | 5 +++++ src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 5 +++++ 3 files changed, 15 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 724aeecfeb59..ddbd550b2b6e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -208,6 +208,11 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithem_type` {algorithm_type}." + ) + # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index c2c7c684d4da..50cf249b8ec8 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -203,6 +203,11 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithem_type` {algorithm_type}." + ) + # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 6f6a4f705bed..24ec14dc8a1c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -190,6 +190,11 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + if algorithm_type != "dpmsolver++" and final_sigmas_type == "denoise_to_zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithem_type` {algorithm_type}." + ) + # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() From aa9d57cce8e08a333f26386b42957b7a49d6bd41 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 16 Jan 2024 06:44:07 -1000 Subject: [PATCH 07/25] Apply suggestions from code review Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- .../schedulers/scheduling_dpmsolver_multistep_inverse.py | 2 +- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index ddbd550b2b6e..620d0a2d1d52 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -210,7 +210,7 @@ def __init__( if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero": raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithem_type` {algorithm_type}." + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}." ) # setable values diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 50cf249b8ec8..c7784da7fae2 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -205,7 +205,7 @@ def __init__( if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero": raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithem_type` {algorithm_type}." + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}." ) # setable values diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 24ec14dc8a1c..fb8fd80930b6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -192,7 +192,7 @@ def __init__( if algorithm_type != "dpmsolver++" and final_sigmas_type == "denoise_to_zero": raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithem_type` {algorithm_type}." + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}." ) # setable values From 15d4fb39adc56dc19ab9013be97a274953681553 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 17 Jan 2024 20:51:04 +0000 Subject: [PATCH 08/25] remove dpmsolver and sde-dpmsolver --- src/diffusers/__init__.py | 2 + src/diffusers/schedulers/__init__.py | 4 +- .../schedulers/deprecated/__init__.py | 2 + .../scheduling_dpmsolver_multistep_legacy.py | 842 ++++++++++++++++++ .../scheduling_dpmsolver_multistep.py | 104 +-- 5 files changed, 858 insertions(+), 96 deletions(-) create mode 100644 src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_legacy.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 180b210953c1..412a61879334 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -142,6 +142,7 @@ "DEISMultistepScheduler", "DPMSolverMultistepInverseScheduler", "DPMSolverMultistepScheduler", + "DPMSolverMultistepSchedulerLegacy", "DPMSolverSinglestepScheduler", "EulerAncestralDiscreteScheduler", "EulerDiscreteScheduler", @@ -519,6 +520,7 @@ DEISMultistepScheduler, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, + DPMSolverMultistepSchedulerLegacy, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index e908ba87acdd..29e2eedcb6cf 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -38,7 +38,7 @@ _dummy_modules.update(get_objects_from_module(dummy_pt_objects)) else: - _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] + _import_structure["deprecated"] = ["DPMSolverMultistepSchedulerLegacy", "KarrasVeScheduler", "ScoreSdeVpScheduler"] _import_structure["scheduling_amused"] = ["AmusedScheduler"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] @@ -129,7 +129,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: - from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler + from .deprecated import DPMSolverMultistepSchedulerLegacy, KarrasVeScheduler, ScoreSdeVpScheduler from .scheduling_amused import AmusedScheduler from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler diff --git a/src/diffusers/schedulers/deprecated/__init__.py b/src/diffusers/schedulers/deprecated/__init__.py index 786707f45206..2ae7bc1d16a3 100644 --- a/src/diffusers/schedulers/deprecated/__init__.py +++ b/src/diffusers/schedulers/deprecated/__init__.py @@ -21,6 +21,7 @@ _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) else: + _import_structure["scheduling_dpmsolver_multistep_legacy"] = ["DPMSolverMultistepSchedulerLegacy"] _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"] _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"] @@ -32,6 +33,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .scheduling_dpmsolver_multistep_legacy import DPMSolverMultistepSchedulerLegacy from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler diff --git a/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_legacy.py b/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_legacy.py new file mode 100644 index 000000000000..8bdf349dab01 --- /dev/null +++ b/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_legacy.py @@ -0,0 +1,842 @@ +# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import deprecate +from ...utils.torch_utils import randn_tensor +from ..scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverMultistepSchedulerLegacy(SchedulerMixin, ConfigMixin): + """ + `DPMSolverMultistepSchedulerLegacy` is a fast dedicated high-order solver for diffusion ODEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver`): + Algorithm type for the solver; can be `dpmsolver`or `sde-dpmsolver`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_lu_lambdas (`bool`, *optional*, defaults to `False`): + Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during + the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of + `lambda(t)`. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver", # "dpmsolver", "sde-dpmsolver" + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + use_lu_lambdas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "sde-dpmsolver"]: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + + if self.config.use_karras_sigmas: + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + elif self.config.use_lu_lambdas: + lambdas = np.flip(log_sigmas.copy()) + lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) + sigmas = np.exp(lambdas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _convert_to_lu(self, in_lambdas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Lu et al. (2022).""" + + lambda_min: float = in_lambdas[-1].item() + lambda_max: float = in_lambdas[0].item() + + rho = 1.0 # 1.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = lambda_min ** (1 / rho) + max_inv_rho = lambda_max ** (1 / rho) + lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return lambdas + + def convert_model_output( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + Convert the model output. DPM-Solver is designed to discretize an integral of the noise prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver needs to solve an integral of the noise prediction model. + if self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + noise: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + noise: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + if self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the third-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.FloatTensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + + if self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.algorithm_type in ["sde-dpmsolver"]: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index ddbd550b2b6e..c2c6bc90969e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -106,9 +106,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The - `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) - paper, and the `dpmsolver++` type implements the algorithms in the + Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. It implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): @@ -196,9 +194,13 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") + elif algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + raise ValueError( + f"`algorithm_type` {algorithm_type} is no longer supported in {self.__class__}. Please use `DPMSolverSchedulerLegacy` instead." + ) else: raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") @@ -208,11 +210,6 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero": - raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithem_type` {algorithm_type}." - ) - # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() @@ -432,14 +429,11 @@ def convert_model_output( **kwargs, ) -> torch.FloatTensor: """ - Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is - designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an - integral of the data prediction model. + Convert the model output to predict data. DPM-Solver++ is designed to discretize an integral of the data prediction model. - The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise - prediction and data prediction models. + The algorithm and model type are decoupled. You can use DPMSolver++ for either noise prediction and data prediction models. @@ -469,7 +463,7 @@ def convert_model_output( # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": - # DPM-Solver and DPM-Solver++ only need the "mean" output. + # DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] sigma = self.sigmas[self.step_index] @@ -492,37 +486,6 @@ def convert_model_output( return x0_pred - # DPM-Solver needs to solve an integral of the noise prediction model. - elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: - if self.config.prediction_type == "epsilon": - # DPM-Solver and DPM-Solver++ only need the "mean" output. - if self.config.variance_type in ["learned", "learned_range"]: - epsilon = model_output[:, :3] - else: - epsilon = model_output - elif self.config.prediction_type == "sample": - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - epsilon = (sample - alpha_t * model_output) / sigma_t - elif self.config.prediction_type == "v_prediction": - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - epsilon = alpha_t * model_output + sigma_t * sample - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DPMSolverMultistepScheduler." - ) - - if self.config.thresholding: - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - x0_pred = (sample - sigma_t * epsilon) / alpha_t - x0_pred = self._threshold_sample(x0_pred) - epsilon = (sample - alpha_t * x0_pred) / sigma_t - - return epsilon - def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, @@ -574,8 +537,6 @@ def dpm_solver_first_order_update( h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output - elif self.config.algorithm_type == "dpmsolver": - x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None x_t = ( @@ -583,13 +544,6 @@ def dpm_solver_first_order_update( + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - x_t = ( - (alpha_t / alpha_s) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) return x_t def multistep_dpm_solver_second_order_update( @@ -667,20 +621,6 @@ def multistep_dpm_solver_second_order_update( - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - ) elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None if self.config.solver_type == "midpoint": @@ -697,22 +637,6 @@ def multistep_dpm_solver_second_order_update( + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * (torch.exp(h) - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) return x_t def multistep_dpm_solver_third_order_update( @@ -790,14 +714,6 @@ def multistep_dpm_solver_third_order_update( + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - ) return x_t def _init_step_index(self, timestep): @@ -872,7 +788,7 @@ def step( self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output - if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + if self.config.algorithm_type in ["sde-dpmsolver++"]: noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype ) From 1827a4c62c4083a049af8d348e70d9a035f20335 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 17 Jan 2024 23:28:50 +0000 Subject: [PATCH 09/25] udpate and deprecate inverse --- ...ling_dpmsolver_multistep_inverse_legacy.py | 878 ++++++++++++++++++ .../scheduling_dpmsolver_multistep_inverse.py | 12 +- 2 files changed, 881 insertions(+), 9 deletions(-) create mode 100644 src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py diff --git a/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py b/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py new file mode 100644 index 000000000000..bd2d4ed2871c --- /dev/null +++ b/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py @@ -0,0 +1,878 @@ +# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import deprecate +from ...utils.torch_utils import randn_tensor +from ..scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverMultistepInverseSchedulerLegacy(SchedulerMixin, ConfigMixin): + """ + `DPMSolverMultistepInverseSchedulerLegacy` is the reverse scheduler of [`DPMSolverMultistepSchedulerLegacy`]. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver`): + Algorithm type for the solver; can be `dpmsolver` or `sde-dpmsolver`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "sde-dpmsolver"]: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.use_karras_sigmas = use_karras_sigmas + + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped).item() + self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = (self.noisiest_timestep + 1) // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.noisiest_timestep + 1, 0, -step_ratio).round()[::-1].copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', " + "'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = timesteps.copy().astype(np.int64) + sigmas = np.concatenate([sigmas[-1:], sigmas]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ( + (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] + ) ** 0.5 + sigmas = np.concatenate([[sigma_last], sigmas]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.convert_model_output + def convert_model_output( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + Convert the model output to predict data. DPM-Solver is designed to discretize an integral of the noise prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver for either noise prediction and data prediction models. + + + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + noise: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + noise: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the third-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.FloatTensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.step + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "denoise_to_zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.scale_model_input + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 50cf249b8ec8..67b65940ea6e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -271,9 +271,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) if self.config.final_sigmas_type == "default": - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + sigmas = np.concatenate([sigmas[0], sigmas]).astype(np.float32) elif self.config.final_sigmas_type == "denoise_to_zero": - sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) + sigmas = np.concatenate([np.array([0]), sigmas]).astype(np.float32) else: raise ValueError( f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" @@ -290,15 +290,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc raise ValueError( f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" ) - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + sigmas = np.concatenate([[sigma_last], sigmas]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - - # when num_inference_steps == num_train_timesteps, we can end up with - # duplicates in timesteps. - _, unique_indices = np.unique(timesteps, return_index=True) - timesteps = timesteps[np.sort(unique_indices)] - self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) self.num_inference_steps = len(timesteps) From 4271d2f370cfa255abe88e46099b5016909d4b08 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 02:04:25 +0000 Subject: [PATCH 10/25] reverse changes to inverse scheduler --- .../scheduling_dpmsolver_multistep_inverse.py | 35 ++++--------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 1b412f486994..fadd465efb42 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -160,7 +160,6 @@ def __init__( lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, - final_sigmas_type: Optional[str] = "default", # "denoise_to_zero", "default" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", @@ -203,11 +202,6 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero": - raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}." - ) - # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() @@ -270,27 +264,13 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) - if self.config.final_sigmas_type == "default": - sigmas = np.concatenate([sigmas[0], sigmas]).astype(np.float32) - elif self.config.final_sigmas_type == "denoise_to_zero": - sigmas = np.concatenate([np.array([0]), sigmas]).astype(np.float32) - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" - ) + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - if self.config.final_sigmas_type == "default": - sigma_last = ( - (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] - ) ** 0.5 - elif self.config.final_sigmas_type == "denoise_to_zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" - ) - sigmas = np.concatenate([[sigma_last], sigmas]).astype(np.float32) + sigma_max = ( + (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] + ) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) @@ -797,7 +777,6 @@ def _init_step_index(self, timestep): self._step_index = step_index - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step def step( self, model_output: torch.FloatTensor, @@ -838,9 +817,7 @@ def step( # Improve numerical stability for small number of steps lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( - self.config.euler_at_final - or (self.config.lower_order_final and len(self.timesteps) < 15) - or self.config.final_sigmas_type == "denoise_to_zero" + self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 From 8d5485083539d4c78ccf0469539569013bcae8f4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 02:43:33 +0000 Subject: [PATCH 11/25] fix --- .../schedulers/scheduling_dpmsolver_multistep_inverse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index fadd465efb42..0f64844e865c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -777,6 +777,7 @@ def _init_step_index(self, timestep): self._step_index = step_index + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step def step( self, model_output: torch.FloatTensor, From b23cade876e39a9807ea6f6d4c66e0e2158e6d27 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 02:44:35 +0000 Subject: [PATCH 12/25] remove file --- ...ling_dpmsolver_multistep_inverse_legacy.py | 878 ------------------ 1 file changed, 878 deletions(-) delete mode 100644 src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py diff --git a/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py b/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py deleted file mode 100644 index bd2d4ed2871c..000000000000 --- a/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py +++ /dev/null @@ -1,878 +0,0 @@ -# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import deprecate -from ...utils.torch_utils import randn_tensor -from ..scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - if alpha_transform_type == "cosine": - - def alpha_bar_fn(t): - return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 - - elif alpha_transform_type == "exp": - - def alpha_bar_fn(t): - return math.exp(t * -12.0) - - else: - raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class DPMSolverMultistepInverseSchedulerLegacy(SchedulerMixin, ConfigMixin): - """ - `DPMSolverMultistepInverseSchedulerLegacy` is the reverse scheduler of [`DPMSolverMultistepSchedulerLegacy`]. - - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. - - Args: - num_train_timesteps (`int`, defaults to 1000): - The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): - The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): - The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, *optional*): - Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. - solver_order (`int`, defaults to 2): - The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided - sampling, and `solver_order=3` for unconditional sampling. - prediction_type (`str`, defaults to `epsilon`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://imagen.research.google/video/paper.pdf) paper). - thresholding (`bool`, defaults to `False`): - Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such - as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): - The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): - The threshold value for dynamic thresholding. Valid only when `thresholding=True` and - `algorithm_type="dpmsolver++"`. - algorithm_type (`str`, defaults to `dpmsolver`): - Algorithm type for the solver; can be `dpmsolver` or `sde-dpmsolver`. The - `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) - paper. - solver_type (`str`, defaults to `midpoint`): - Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the - sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. - lower_order_final (`bool`, defaults to `True`): - Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can - stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. - euler_at_final (`bool`, defaults to `False`): - Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail - richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference - steps, but sometimes may result in blurring. - use_karras_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, - the sigmas are determined according to a sequence of noise levels {σi}. - lambda_min_clipped (`float`, defaults to `-inf`): - Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the - cosine (`squaredcos_cap_v2`) noise schedule. - variance_type (`str`, *optional*): - Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output - contains the predicted Gaussian variance. - timestep_spacing (`str`, defaults to `"linspace"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): - An offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable - Diffusion. - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - solver_order: int = 2, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver", - solver_type: str = "midpoint", - lower_order_final: bool = True, - euler_at_final: bool = False, - use_karras_sigmas: Optional[bool] = False, - lambda_min_clipped: float = -float("inf"), - variance_type: Optional[str] = None, - timestep_spacing: str = "linspace", - steps_offset: int = 0, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Currently we only support VP-type noise schedule - self.alpha_t = torch.sqrt(self.alphas_cumprod) - self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) - self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "sde-dpmsolver"]: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") - - if solver_type not in ["midpoint", "heun"]: - if solver_type in ["logrho", "bh1", "bh2"]: - self.register_to_config(solver_type="midpoint") - else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() - self.timesteps = torch.from_numpy(timesteps) - self.model_outputs = [None] * solver_order - self.lower_order_nums = 0 - self._step_index = None - self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - self.use_karras_sigmas = use_karras_sigmas - - @property - def step_index(self): - """ - The index counter for current timestep. It will increae 1 after each scheduler step. - """ - return self._step_index - - def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - - Args: - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - # Clipping the minimum of all lambda(t) for numerical stability. - # This is critical for cosine (squaredcos_cap_v2) noise schedule. - clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped).item() - self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx - - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = ( - np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64) - ) - elif self.config.timestep_spacing == "leading": - step_ratio = (self.noisiest_timestep + 1) // (num_inference_steps + 1) - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[:-1].copy().astype(np.int64) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.arange(self.noisiest_timestep + 1, 0, -step_ratio).round()[::-1].copy().astype(np.int64) - timesteps -= 1 - else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', " - "'leading' or 'trailing'." - ) - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - log_sigmas = np.log(sigmas) - - if self.config.use_karras_sigmas: - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - timesteps = timesteps.copy().astype(np.int64) - sigmas = np.concatenate([sigmas[-1:], sigmas]).astype(np.float32) - else: - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_last = ( - (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] - ) ** 0.5 - sigmas = np.concatenate([[sigma_last], sigmas]).astype(np.float32) - - self.sigmas = torch.from_numpy(sigmas) - - self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) - - self.num_inference_steps = len(timesteps) - - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - - # add an index counter for schedulers that allow duplicated timesteps - self._step_index = None - self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample - def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: - """ - "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the - prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by - s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing - pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights." - - https://arxiv.org/abs/2205.11487 - """ - dtype = sample.dtype - batch_size, channels, *remaining_dims = sample.shape - - if dtype not in (torch.float32, torch.float64): - sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half - - # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - - abs_sample = sample.abs() # "a certain percentile absolute pixel value" - - s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) - s = torch.clamp( - s, min=1, max=self.config.sample_max_value - ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] - s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 - sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" - - sample = sample.reshape(batch_size, channels, *remaining_dims) - sample = sample.to(dtype) - - return sample - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): - # get log sigma - log_sigma = np.log(np.maximum(sigma, 1e-10)) - - # get distribution - dists = log_sigma - log_sigmas[:, np.newaxis] - - # get sigmas range - low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = log_sigmas[low_idx] - high = log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = np.clip(w, 0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.reshape(sigma.shape) - return t - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t - def _sigma_to_alpha_sigma_t(self, sigma): - alpha_t = 1 / ((sigma**2 + 1) ** 0.5) - sigma_t = sigma * alpha_t - - return alpha_t, sigma_t - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: - """Constructs the noise schedule of Karras et al. (2022).""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - rho = 7.0 # 7.0 is the value used in the paper - ramp = np.linspace(0, 1, num_inference_steps) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return sigmas - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.convert_model_output - def convert_model_output( - self, - model_output: torch.FloatTensor, - *args, - sample: torch.FloatTensor = None, - **kwargs, - ) -> torch.FloatTensor: - """ - Convert the model output to predict data. DPM-Solver is designed to discretize an integral of the noise prediction model. - - - - The algorithm and model type are decoupled. You can use either DPMSolver for either noise prediction and data prediction models. - - - - Args: - model_output (`torch.FloatTensor`): - The direct output from the learned diffusion model. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.FloatTensor`: - The converted model output. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - if sample is None: - if len(args) > 1: - sample = args[1] - else: - raise ValueError("missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: - if self.config.prediction_type == "epsilon": - # DPM-Solver and DPM-Solver++ only need the "mean" output. - if self.config.variance_type in ["learned", "learned_range"]: - epsilon = model_output[:, :3] - else: - epsilon = model_output - elif self.config.prediction_type == "sample": - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - epsilon = (sample - alpha_t * model_output) / sigma_t - elif self.config.prediction_type == "v_prediction": - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - epsilon = alpha_t * model_output + sigma_t * sample - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DPMSolverMultistepScheduler." - ) - - if self.config.thresholding: - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - x0_pred = (sample - sigma_t * epsilon) / alpha_t - x0_pred = self._threshold_sample(x0_pred) - epsilon = (sample - alpha_t * x0_pred) / sigma_t - - return epsilon - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.dpm_solver_first_order_update - def dpm_solver_first_order_update( - self, - model_output: torch.FloatTensor, - *args, - sample: torch.FloatTensor = None, - noise: Optional[torch.FloatTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - """ - One step for the first-order DPMSolver (equivalent to DDIM). - - Args: - model_output (`torch.FloatTensor`): - The direct output from the learned diffusion model. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.FloatTensor`: - The sample tensor at the previous timestep. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s = torch.log(alpha_s) - torch.log(sigma_s) - - h = lambda_t - lambda_s - if self.config.algorithm_type == "dpmsolver++": - x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output - elif self.config.algorithm_type == "dpmsolver": - x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output - elif self.config.algorithm_type == "sde-dpmsolver++": - assert noise is not None - x_t = ( - (sigma_t / sigma_s * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - x_t = ( - (alpha_t / alpha_s) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - return x_t - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.multistep_dpm_solver_second_order_update - def multistep_dpm_solver_second_order_update( - self, - model_output_list: List[torch.FloatTensor], - *args, - sample: torch.FloatTensor = None, - noise: Optional[torch.FloatTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - """ - One step for the second-order multistep DPMSolver. - - Args: - model_output_list (`List[torch.FloatTensor]`): - The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.FloatTensor`: - The sample tensor at the previous timestep. - """ - timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing `sample` as a required keyward argument") - if timestep_list is not None: - deprecate( - "timestep_list", - "1.0.0", - "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s0, sigma_s1 = ( - self.sigmas[self.step_index + 1], - self.sigmas[self.step_index], - self.sigmas[self.step_index - 1], - ) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) - - m0, m1 = model_output_list[-1], model_output_list[-2] - - h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 - r0 = h_0 / h - D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2211.01095 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - ) - elif self.config.algorithm_type == "sde-dpmsolver++": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 - + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.solver_type == "heun": - x_t = ( - (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 - + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * (torch.exp(h) - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - return x_t - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.multistep_dpm_solver_third_order_update - def multistep_dpm_solver_third_order_update( - self, - model_output_list: List[torch.FloatTensor], - *args, - sample: torch.FloatTensor = None, - **kwargs, - ) -> torch.FloatTensor: - """ - One step for the third-order multistep DPMSolver. - - Args: - model_output_list (`List[torch.FloatTensor]`): - The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.FloatTensor`): - A current instance of a sample created by diffusion process. - - Returns: - `torch.FloatTensor`: - The sample tensor at the previous timestep. - """ - - timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing`sample` as a required keyward argument") - if timestep_list is not None: - deprecate( - "timestep_list", - "1.0.0", - "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( - self.sigmas[self.step_index + 1], - self.sigmas[self.step_index], - self.sigmas[self.step_index - 1], - self.sigmas[self.step_index - 2], - ) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) - alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) - lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) - - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - - h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 - r0, r1 = h_0 / h, h_1 / h - D0 = m0 - D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) - D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) - D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - ) - return x_t - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - - index_candidates = (self.timesteps == timestep).nonzero() - - if len(index_candidates) == 0: - step_index = len(self.timesteps) - 1 - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() - - self._step_index = step_index - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.step - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - generator=None, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with - the multistep DPMSolver. - - Args: - model_output (`torch.FloatTensor`): - The direct output from learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - generator (`torch.Generator`, *optional*): - A random number generator. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if self.step_index is None: - self._init_step_index(timestep) - - # Improve numerical stability for small number of steps - lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( - self.config.euler_at_final - or (self.config.lower_order_final and len(self.timesteps) < 15) - or self.config.final_sigmas_type == "denoise_to_zero" - ) - lower_order_second = ( - (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 - ) - - model_output = self.convert_model_output(model_output, sample=sample) - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.model_outputs[-1] = model_output - - if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: - noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype - ) - else: - noise = None - - if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) - elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) - else: - prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - # upon completion increase step index by one - self._step_index += 1 - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.scale_model_input - def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): - The input sample. - - Returns: - `torch.FloatTensor`: - A scaled input sample. - """ - return sample - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - step_indices = [] - for timestep in timesteps: - index_candidates = (schedule_timesteps == timestep).nonzero() - if len(index_candidates) == 0: - step_index = len(schedule_timesteps) - 1 - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() - step_indices.append(step_index) - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - noisy_samples = alpha_t * original_samples + sigma_t * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps From a2e090329358c8ffc9d8faf69f3c1a166f890955 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 03:58:49 +0000 Subject: [PATCH 13/25] deprecate dpmsolver from dpm singlestep --- src/diffusers/__init__.py | 2 + src/diffusers/schedulers/__init__.py | 14 +- .../schedulers/deprecated/__init__.py | 2 + .../scheduling_dpmsolver_singlestep_legacy.py | 866 ++++++++++++++++++ .../scheduling_dpmsolver_singlestep.py | 75 +- 5 files changed, 891 insertions(+), 68 deletions(-) create mode 100644 src/diffusers/schedulers/deprecated/scheduling_dpmsolver_singlestep_legacy.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 412a61879334..bbf33a949241 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -144,6 +144,7 @@ "DPMSolverMultistepScheduler", "DPMSolverMultistepSchedulerLegacy", "DPMSolverSinglestepScheduler", + "DPMSolverSinglestepSchedulerLegacy", "EulerAncestralDiscreteScheduler", "EulerDiscreteScheduler", "HeunDiscreteScheduler", @@ -522,6 +523,7 @@ DPMSolverMultistepScheduler, DPMSolverMultistepSchedulerLegacy, DPMSolverSinglestepScheduler, + DPMSolverSinglestepSchedulerLegacy, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 29e2eedcb6cf..7cb52cb517f7 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -38,7 +38,12 @@ _dummy_modules.update(get_objects_from_module(dummy_pt_objects)) else: - _import_structure["deprecated"] = ["DPMSolverMultistepSchedulerLegacy", "KarrasVeScheduler", "ScoreSdeVpScheduler"] + _import_structure["deprecated"] = [ + "DPMSolverSinglestepSchedulerLegacy", + "DPMSolverSinglestepSchedulerLegacy", + "KarrasVeScheduler", + "ScoreSdeVpScheduler", + ] _import_structure["scheduling_amused"] = ["AmusedScheduler"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] @@ -129,7 +134,12 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: - from .deprecated import DPMSolverMultistepSchedulerLegacy, KarrasVeScheduler, ScoreSdeVpScheduler + from .deprecated import ( + DPMSolverMultistepSchedulerLegacy, + DPMSolverSinglestepSchedulerLegacy, + KarrasVeScheduler, + ScoreSdeVpScheduler, + ) from .scheduling_amused import AmusedScheduler from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler diff --git a/src/diffusers/schedulers/deprecated/__init__.py b/src/diffusers/schedulers/deprecated/__init__.py index 2ae7bc1d16a3..665e8dfb30ac 100644 --- a/src/diffusers/schedulers/deprecated/__init__.py +++ b/src/diffusers/schedulers/deprecated/__init__.py @@ -22,6 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) else: _import_structure["scheduling_dpmsolver_multistep_legacy"] = ["DPMSolverMultistepSchedulerLegacy"] + _import_structure["scheduling_dpmsolver_singlestep_legacy"] = ["DPMSolverSinglestepSchedulerLegacy"] _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"] _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"] @@ -34,6 +35,7 @@ from ..utils.dummy_pt_objects import * # noqa F403 else: from .scheduling_dpmsolver_multistep_legacy import DPMSolverMultistepSchedulerLegacy + from .scheduling_dpmsolver_singlestep_legacy import DPMSolverSinglestepSchedulerLegacy from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler diff --git a/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_singlestep_legacy.py b/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_singlestep_legacy.py new file mode 100644 index 000000000000..f018de65a034 --- /dev/null +++ b/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_singlestep_legacy.py @@ -0,0 +1,866 @@ +# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import deprecate, logging +from ..scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverSinglestepSchedulerLegacy(SchedulerMixin, ConfigMixin): + """ + `DPMSolverSinglestepSchedulerLegacy` is a fast dedicated high-order solver for diffusion ODEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver`): + Algorithm type for the solver: only support `dpmsolver`. It implements the algorithm in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver", + solver_type: str = "midpoint", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver"]: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.sample = None + self.order_list = self.get_order_list(num_train_timesteps) + self._step_index = None + self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def get_order_list(self, num_inference_steps: int) -> List[int]: + """ + Computes the solver order at each time step. + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + steps = num_inference_steps + order = self.config.solver_order + if self.config.lower_order_final: + if order == 3: + if steps % 3 == 0: + orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1] + elif steps % 3 == 1: + orders = [1, 2, 3] * (steps // 3) + [1] + else: + orders = [1, 2, 3] * (steps // 3) + [1, 2] + elif order == 2: + if steps % 2 == 0: + orders = [1, 2] * (steps // 2) + else: + orders = [1, 2] * (steps // 2) + [1] + elif order == 1: + orders = [1] * steps + else: + if order == 3: + orders = [1, 2, 3] * (steps // 3) + elif order == 2: + orders = [1, 2] * (steps // 2) + elif order == 1: + orders = [1] * steps + return orders + + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas).to(device=device) + + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + self.model_outputs = [None] * self.config.solver_order + self.sample = None + + if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0: + logger.warn( + "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`." + ) + self.register_to_config(lower_order_final=True) + + if not self.config.lower_order_final and self.config.final_sigmas_type == "denoise_to_zero": + logger.warn( + "denoise_to_zero is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True." + ) + self.register_to_config(lower_order_final=True) + + self.order_list = self.get_order_list(num_inference_steps) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def convert_model_output( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + Convert the model output to noise. DPM-Solver is designed to discretize an integral of the noise prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.prediction_type == "epsilon": + # DPM-Solver only need the "mean" output. + if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] + return model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverSinglestepScheduler." + ) + + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + return x_t + + def singlestep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the + time `timestep_list[-2]`. + + Args: + model_output_list (`List[torch.FloatTensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): + The current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m1, (1.0 / r0) * (m0 - m1) + + if self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + return x_t + + def singlestep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the + time `timestep_list[-3]`. + + Args: + model_output_list (`List[torch.FloatTensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): + The current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m2 + D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2) + D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1) + + if self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s2) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s2) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def singlestep_dpm_solver_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + order: int = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the singlestep DPMSolver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): + The current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by diffusion process. + order (`int`): + The solver order at this step. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing `order` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if order == 1: + return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample) + elif order == 2: + return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample) + elif order == 3: + return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample) + else: + raise ValueError(f"Order must be 1, 2, 3, got {order}") + + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the singlestep DPMSolver. + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + order = self.order_list[self.step_index] + + # For img2img denoising might start with order>1 which is not possible + # In this case make sure that the first two steps are both order=1 + while self.model_outputs[-order] is None: + order -= 1 + + # For single-step solvers, we use the initial value at each time with order = 1. + if order == 1: + self.sample = sample + + prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index fb8fd80930b6..e595e57ef2c6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -108,9 +108,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The - `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) - paper, and the `dpmsolver++` type implements the algorithms in the + Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): @@ -179,9 +177,13 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + if algorithm_type not in ["dpmsolver++"]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") + elif algorithm_type == "dpmsolver": + raise ValueError( + f"`algorithm_type` {algorithm_type} is no longer supported in {self.__class__}. Please use `DPMSolverSinglestepSchedulerLegacy` instead." + ) else: raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") if solver_type not in ["midpoint", "heun"]: @@ -190,11 +192,6 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - if algorithm_type != "dpmsolver++" and final_sigmas_type == "denoise_to_zero": - raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}." - ) - # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() @@ -416,13 +413,12 @@ def convert_model_output( **kwargs, ) -> torch.FloatTensor: """ - Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is - designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + Convert the model output to predict data. DPM-Solver++ is designed to discretize an integral of the data prediction model. - The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + The algorithm and model type are decoupled. You can use either DPMSolver++ for both noise prediction and data prediction models. @@ -452,7 +448,7 @@ def convert_model_output( # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type == "dpmsolver++": if self.config.prediction_type == "epsilon": - # DPM-Solver and DPM-Solver++ only need the "mean" output. + # DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned_range"]: model_output = model_output[:, :3] sigma = self.sigmas[self.step_index] @@ -474,28 +470,6 @@ def convert_model_output( x0_pred = self._threshold_sample(x0_pred) return x0_pred - # DPM-Solver needs to solve an integral of the noise prediction model. - elif self.config.algorithm_type == "dpmsolver": - if self.config.prediction_type == "epsilon": - # DPM-Solver and DPM-Solver++ only need the "mean" output. - if self.config.variance_type in ["learned_range"]: - model_output = model_output[:, :3] - return model_output - elif self.config.prediction_type == "sample": - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon - elif self.config.prediction_type == "v_prediction": - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - epsilon = alpha_t * model_output + sigma_t * sample - return epsilon - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DPMSolverSinglestepScheduler." - ) def dpm_solver_first_order_update( self, @@ -549,8 +523,6 @@ def dpm_solver_first_order_update( h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output - elif self.config.algorithm_type == "dpmsolver": - x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output return x_t def singlestep_dpm_solver_second_order_update( @@ -631,20 +603,6 @@ def singlestep_dpm_solver_second_order_update( - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s1) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s1) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - ) return x_t def singlestep_dpm_solver_third_order_update( @@ -734,21 +692,6 @@ def singlestep_dpm_solver_third_order_update( + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s2) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s2) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - ) return x_t def singlestep_dpm_solver_update( From c8e9f8be8742af3ae9e1a76f06301f6a712e2f6a Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 17 Jan 2024 18:00:20 -1000 Subject: [PATCH 14/25] Update src/diffusers/schedulers/__init__.py --- src/diffusers/schedulers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 7cb52cb517f7..b8b760b3080f 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -39,7 +39,7 @@ else: _import_structure["deprecated"] = [ - "DPMSolverSinglestepSchedulerLegacy", + "DPMSolverMultistepSchedulerLegacy", "DPMSolverSinglestepSchedulerLegacy", "KarrasVeScheduler", "ScoreSdeVpScheduler", From d084322eb5e732c7c024d5eb740cdf5c0f7fd11c Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 17 Jan 2024 18:01:38 -1000 Subject: [PATCH 15/25] Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index c2c6bc90969e..c882048c930b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -199,7 +199,7 @@ def __init__( self.register_to_config(algorithm_type="dpmsolver++") elif algorithm_type in ["dpmsolver", "sde-dpmsolver"]: raise ValueError( - f"`algorithm_type` {algorithm_type} is no longer supported in {self.__class__}. Please use `DPMSolverSchedulerLegacy` instead." + f"`algorithm_type` {algorithm_type} is no longer supported in {self.__class__}. Please use `DPMSolverMultistepSchedulerLegacy` instead." ) else: raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") From 49b8e80261e84ee7d84c6bdfb9aeafffcfdacbb1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 19:03:11 +0000 Subject: [PATCH 16/25] Revert "reverse changes to inverse scheduler" This reverts commit 4271d2f370cfa255abe88e46099b5016909d4b08. --- src/diffusers/__init__.py | 2 - src/diffusers/schedulers/__init__.py | 14 +- .../schedulers/deprecated/__init__.py | 2 - ...ing_dpmsolver_multistep_inverse_legacy.py} | 420 +++++++++--------- .../scheduling_dpmsolver_multistep.py | 2 +- .../scheduling_dpmsolver_multistep_inverse.py | 34 +- .../scheduling_dpmsolver_singlestep.py | 75 +++- 7 files changed, 313 insertions(+), 236 deletions(-) rename src/diffusers/schedulers/deprecated/{scheduling_dpmsolver_singlestep_legacy.py => scheduling_dpmsolver_multistep_inverse_legacy.py} (70%) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bbf33a949241..412a61879334 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -144,7 +144,6 @@ "DPMSolverMultistepScheduler", "DPMSolverMultistepSchedulerLegacy", "DPMSolverSinglestepScheduler", - "DPMSolverSinglestepSchedulerLegacy", "EulerAncestralDiscreteScheduler", "EulerDiscreteScheduler", "HeunDiscreteScheduler", @@ -523,7 +522,6 @@ DPMSolverMultistepScheduler, DPMSolverMultistepSchedulerLegacy, DPMSolverSinglestepScheduler, - DPMSolverSinglestepSchedulerLegacy, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b8b760b3080f..29e2eedcb6cf 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -38,12 +38,7 @@ _dummy_modules.update(get_objects_from_module(dummy_pt_objects)) else: - _import_structure["deprecated"] = [ - "DPMSolverMultistepSchedulerLegacy", - "DPMSolverSinglestepSchedulerLegacy", - "KarrasVeScheduler", - "ScoreSdeVpScheduler", - ] + _import_structure["deprecated"] = ["DPMSolverMultistepSchedulerLegacy", "KarrasVeScheduler", "ScoreSdeVpScheduler"] _import_structure["scheduling_amused"] = ["AmusedScheduler"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] @@ -134,12 +129,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: - from .deprecated import ( - DPMSolverMultistepSchedulerLegacy, - DPMSolverSinglestepSchedulerLegacy, - KarrasVeScheduler, - ScoreSdeVpScheduler, - ) + from .deprecated import DPMSolverMultistepSchedulerLegacy, KarrasVeScheduler, ScoreSdeVpScheduler from .scheduling_amused import AmusedScheduler from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler diff --git a/src/diffusers/schedulers/deprecated/__init__.py b/src/diffusers/schedulers/deprecated/__init__.py index 665e8dfb30ac..2ae7bc1d16a3 100644 --- a/src/diffusers/schedulers/deprecated/__init__.py +++ b/src/diffusers/schedulers/deprecated/__init__.py @@ -22,7 +22,6 @@ _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) else: _import_structure["scheduling_dpmsolver_multistep_legacy"] = ["DPMSolverMultistepSchedulerLegacy"] - _import_structure["scheduling_dpmsolver_singlestep_legacy"] = ["DPMSolverSinglestepSchedulerLegacy"] _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"] _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"] @@ -35,7 +34,6 @@ from ..utils.dummy_pt_objects import * # noqa F403 else: from .scheduling_dpmsolver_multistep_legacy import DPMSolverMultistepSchedulerLegacy - from .scheduling_dpmsolver_singlestep_legacy import DPMSolverSinglestepSchedulerLegacy from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler diff --git a/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_singlestep_legacy.py b/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py similarity index 70% rename from src/diffusers/schedulers/deprecated/scheduling_dpmsolver_singlestep_legacy.py rename to src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py index f018de65a034..bd2d4ed2871c 100644 --- a/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_singlestep_legacy.py +++ b/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py @@ -21,13 +21,11 @@ import torch from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import deprecate, logging +from ...utils import deprecate +from ...utils.torch_utils import randn_tensor from ..scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -73,9 +71,9 @@ def alpha_bar_fn(t): return torch.tensor(betas, dtype=torch.float32) -class DPMSolverSinglestepSchedulerLegacy(SchedulerMixin, ConfigMixin): +class DPMSolverMultistepInverseSchedulerLegacy(SchedulerMixin, ConfigMixin): """ - `DPMSolverSinglestepSchedulerLegacy` is a fast dedicated high-order solver for diffusion ODEs. + `DPMSolverMultistepInverseSchedulerLegacy` is the reverse scheduler of [`DPMSolverMultistepSchedulerLegacy`]. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. @@ -108,7 +106,8 @@ class DPMSolverSinglestepSchedulerLegacy(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver`): - Algorithm type for the solver: only support `dpmsolver`. It implements the algorithm in the [DPMSolver](https://huggingface.co/papers/2206.00927) + Algorithm type for the solver; can be `dpmsolver` or `sde-dpmsolver`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the @@ -116,6 +115,10 @@ class DPMSolverSinglestepSchedulerLegacy(SchedulerMixin, ConfigMixin): lower_order_final (`bool`, defaults to `True`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. @@ -125,6 +128,13 @@ class DPMSolverSinglestepSchedulerLegacy(SchedulerMixin, ConfigMixin): variance_type (`str`, *optional*): Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -137,7 +147,7 @@ def __init__( beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, @@ -146,9 +156,12 @@ def __init__( algorithm_type: str = "dpmsolver", solver_type: str = "midpoint", lower_order_final: bool = True, + euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -175,8 +188,9 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver"]: + if algorithm_type not in ["dpmsolver", "sde-dpmsolver"]: raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + if solver_type not in ["midpoint", "heun"]: if solver_type in ["logrho", "bh1", "bh2"]: self.register_to_config(solver_type="midpoint") @@ -185,47 +199,13 @@ def __init__( # setable values self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order - self.sample = None - self.order_list = self.get_order_list(num_train_timesteps) + self.lower_order_nums = 0 self._step_index = None self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - - def get_order_list(self, num_inference_steps: int) -> List[int]: - """ - Computes the solver order at each time step. - - Args: - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. - """ - steps = num_inference_steps - order = self.config.solver_order - if self.config.lower_order_final: - if order == 3: - if steps % 3 == 0: - orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1] - elif steps % 3 == 1: - orders = [1, 2, 3] * (steps // 3) + [1] - else: - orders = [1, 2, 3] * (steps // 3) + [1, 2] - elif order == 2: - if steps % 2 == 0: - orders = [1, 2] * (steps // 2) - else: - orders = [1, 2] * (steps // 2) + [1] - elif order == 1: - orders = [1] * steps - else: - if order == 3: - orders = [1, 2, 3] * (steps // 3) - elif order == 2: - orders = [1, 2] * (steps // 2) - elif order == 1: - orders = [1] * steps - return orders + self.use_karras_sigmas = use_karras_sigmas @property def step_index(self): @@ -234,7 +214,7 @@ def step_index(self): """ return self._step_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -244,48 +224,59 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - self.num_inference_steps = num_inference_steps # Clipping the minimum of all lambda(t) for numerical stability. # This is critical for cosine (squaredcos_cap_v2) noise schedule. - clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped).item() + self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = (self.noisiest_timestep + 1) // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.noisiest_timestep + 1, 0, -step_ratio).round()[::-1].copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', " + "'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + if self.config.use_karras_sigmas: - log_sigmas = np.log(sigmas) - sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + timesteps = timesteps.copy().astype(np.int64) + sigmas = np.concatenate([sigmas[-1:], sigmas]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + sigma_last = ( + (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] + ) ** 0.5 + sigmas = np.concatenate([[sigma_last], sigmas]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas).to(device=device) + self.sigmas = torch.from_numpy(sigmas) self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) - self.model_outputs = [None] * self.config.solver_order - self.sample = None - if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0: - logger.warn( - "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`." - ) - self.register_to_config(lower_order_final=True) - - if not self.config.lower_order_final and self.config.final_sigmas_type == "denoise_to_zero": - logger.warn( - "denoise_to_zero is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True." - ) - self.register_to_config(lower_order_final=True) + self.num_inference_steps = len(timesteps) - self.order_list = self.get_order_list(num_inference_steps) + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 # add an index counter for schedulers that allow duplicated timesteps self._step_index = None @@ -382,6 +373,7 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.convert_model_output def convert_model_output( self, model_output: torch.FloatTensor, @@ -390,12 +382,11 @@ def convert_model_output( **kwargs, ) -> torch.FloatTensor: """ - Convert the model output to noise. DPM-Solver is designed to discretize an integral of the noise prediction model. + Convert the model output to predict data. DPM-Solver is designed to discretize an integral of the noise prediction model. - The algorithm and model type are decoupled. You can use either DPMSolver for both noise - prediction and data prediction models. + The algorithm and model type are decoupled. You can use either DPMSolver for either noise prediction and data prediction models. @@ -421,34 +412,44 @@ def convert_model_output( "1.0.0", "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) - # DPM-Solver needs to solve an integral of the noise prediction model. - elif self.config.algorithm_type == "dpmsolver": + + if self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: if self.config.prediction_type == "epsilon": - # DPM-Solver only need the "mean" output. - if self.config.variance_type in ["learned_range"]: - model_output = model_output[:, :3] - return model_output + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output elif self.config.prediction_type == "sample": sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon elif self.config.prediction_type == "v_prediction": sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = alpha_t * model_output + sigma_t * sample - return epsilon else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DPMSolverSinglestepScheduler." + " `v_prediction` for the DPMSolverMultistepScheduler." ) + if self.config.thresholding: + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.dpm_solver_first_order_update def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, *args, sample: torch.FloatTensor = None, + noise: Optional[torch.FloatTensor] = None, **kwargs, ) -> torch.FloatTensor: """ @@ -457,10 +458,6 @@ def dpm_solver_first_order_update( Args: model_output (`torch.FloatTensor`): The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -488,34 +485,49 @@ def dpm_solver_first_order_update( "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s - if self.config.algorithm_type == "dpmsolver": + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) return x_t - def singlestep_dpm_solver_second_order_update( + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], *args, sample: torch.FloatTensor = None, + noise: Optional[torch.FloatTensor] = None, **kwargs, ) -> torch.FloatTensor: """ - One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the - time `timestep_list[-2]`. + One step for the second-order multistep DPMSolver. Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -543,6 +555,7 @@ def singlestep_dpm_solver_second_order_update( "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) + sigma_t, sigma_s0, sigma_s1 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], @@ -559,27 +572,73 @@ def singlestep_dpm_solver_second_order_update( m0, m1 = model_output_list[-1], model_output_list[-2] - h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h - D0, D1 = m1, (1.0 / r0) * (m0 - m1) - - if self.config.algorithm_type == "dpmsolver": + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( - (alpha_t / alpha_s1) * sample + (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( - (alpha_t / alpha_s1) * sample + (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) return x_t - def singlestep_dpm_solver_third_order_update( + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], *args, @@ -587,16 +646,11 @@ def singlestep_dpm_solver_third_order_update( **kwargs, ) -> torch.FloatTensor: """ - One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the - time `timestep_list[-3]`. + One step for the third-order multistep DPMSolver. Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by diffusion process. @@ -645,92 +699,31 @@ def singlestep_dpm_solver_third_order_update( m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h - D0 = m2 - D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2) - D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1) - D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1) - - if self.config.algorithm_type == "dpmsolver": + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s2) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s2) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - ) - return x_t - - def singlestep_dpm_solver_update( - self, - model_output_list: List[torch.FloatTensor], - *args, - sample: torch.FloatTensor = None, - order: int = None, - **kwargs, - ) -> torch.FloatTensor: - """ - One step for the singlestep DPMSolver. - - Args: - model_output_list (`List[torch.FloatTensor]`): - The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - A current instance of a sample created by diffusion process. - order (`int`): - The solver order at this step. - - Returns: - `torch.FloatTensor`: - The sample tensor at the previous timestep. - """ - timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing`sample` as a required keyward argument") - if order is None: - if len(args) > 3: - order = args[3] - else: - raise ValueError(" missing `order` as a required keyward argument") - if timestep_list is not None: - deprecate( - "timestep_list", - "1.0.0", - "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) + return x_t - if order == 1: - return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample) - elif order == 2: - return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample) - elif order == 3: - return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample) - else: - raise ValueError(f"Order must be 1, 2, 3, got {order}") - + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index def _init_step_index(self, timestep): if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -750,16 +743,18 @@ def _init_step_index(self, timestep): self._step_index = step_index + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.step def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, + generator=None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with - the singlestep DPMSolver. + the multistep DPMSolver. Args: model_output (`torch.FloatTensor`): @@ -768,6 +763,8 @@ def step( The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. @@ -785,23 +782,37 @@ def step( if self.step_index is None: self._init_step_index(timestep) + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "denoise_to_zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output - order = self.order_list[self.step_index] - - # For img2img denoising might start with order>1 which is not possible - # In this case make sure that the first two steps are both order=1 - while self.model_outputs[-order] is None: - order -= 1 + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None - # For single-step solvers, we use the initial value at each time with order = 1. - if order == 1: - self.sample = sample + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) - prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order) + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 # upon completion increase step index by one self._step_index += 1 @@ -811,6 +822,7 @@ def step( return SchedulerOutput(prev_sample=prev_sample) + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.scale_model_input def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index c882048c930b..c2c6bc90969e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -199,7 +199,7 @@ def __init__( self.register_to_config(algorithm_type="dpmsolver++") elif algorithm_type in ["dpmsolver", "sde-dpmsolver"]: raise ValueError( - f"`algorithm_type` {algorithm_type} is no longer supported in {self.__class__}. Please use `DPMSolverMultistepSchedulerLegacy` instead." + f"`algorithm_type` {algorithm_type} is no longer supported in {self.__class__}. Please use `DPMSolverSchedulerLegacy` instead." ) else: raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 0f64844e865c..1b412f486994 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -160,6 +160,7 @@ def __init__( lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "default", # "denoise_to_zero", "default" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", @@ -202,6 +203,11 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}." + ) + # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() @@ -264,13 +270,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + if self.config.final_sigmas_type == "default": + sigmas = np.concatenate([sigmas[0], sigmas]).astype(np.float32) + elif self.config.final_sigmas_type == "denoise_to_zero": + sigmas = np.concatenate([np.array([0]), sigmas]).astype(np.float32) + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + ) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_max = ( - (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] - ) ** 0.5 - sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32) + if self.config.final_sigmas_type == "default": + sigma_last = ( + (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] + ) ** 0.5 + elif self.config.final_sigmas_type == "denoise_to_zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([[sigma_last], sigmas]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) @@ -818,7 +838,9 @@ def step( # Improve numerical stability for small number of steps lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( - self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "denoise_to_zero" ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index e595e57ef2c6..fb8fd80930b6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -108,7 +108,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver++` type implements the algorithms in the + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): @@ -177,13 +179,9 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver++"]: + if algorithm_type not in ["dpmsolver", "dpmsolver++"]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") - elif algorithm_type == "dpmsolver": - raise ValueError( - f"`algorithm_type` {algorithm_type} is no longer supported in {self.__class__}. Please use `DPMSolverSinglestepSchedulerLegacy` instead." - ) else: raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") if solver_type not in ["midpoint", "heun"]: @@ -192,6 +190,11 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + if algorithm_type != "dpmsolver++" and final_sigmas_type == "denoise_to_zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}." + ) + # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() @@ -413,12 +416,13 @@ def convert_model_output( **kwargs, ) -> torch.FloatTensor: """ - Convert the model output to predict data. DPM-Solver++ is designed to discretize an + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an integral of the data prediction model. - The algorithm and model type are decoupled. You can use either DPMSolver++ for both noise + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise prediction and data prediction models. @@ -448,7 +452,7 @@ def convert_model_output( # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type == "dpmsolver++": if self.config.prediction_type == "epsilon": - # DPM-Solver++ only need the "mean" output. + # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned_range"]: model_output = model_output[:, :3] sigma = self.sigmas[self.step_index] @@ -470,6 +474,28 @@ def convert_model_output( x0_pred = self._threshold_sample(x0_pred) return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] + return model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverSinglestepScheduler." + ) def dpm_solver_first_order_update( self, @@ -523,6 +549,8 @@ def dpm_solver_first_order_update( h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output return x_t def singlestep_dpm_solver_second_order_update( @@ -603,6 +631,20 @@ def singlestep_dpm_solver_second_order_update( - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) return x_t def singlestep_dpm_solver_third_order_update( @@ -692,6 +734,21 @@ def singlestep_dpm_solver_third_order_update( + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s2) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s2) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) return x_t def singlestep_dpm_solver_update( From 179acd70a2f3c85cfb54a28ecf96917ff2809daf Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 19:05:21 +0000 Subject: [PATCH 17/25] Revert "udpate and deprecate inverse" This reverts commit 1827a4c62c4083a049af8d348e70d9a035f20335. --- ...ling_dpmsolver_multistep_inverse_legacy.py | 878 ------------------ .../scheduling_dpmsolver_multistep_inverse.py | 12 +- 2 files changed, 9 insertions(+), 881 deletions(-) delete mode 100644 src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py diff --git a/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py b/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py deleted file mode 100644 index bd2d4ed2871c..000000000000 --- a/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_inverse_legacy.py +++ /dev/null @@ -1,878 +0,0 @@ -# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import deprecate -from ...utils.torch_utils import randn_tensor -from ..scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - if alpha_transform_type == "cosine": - - def alpha_bar_fn(t): - return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 - - elif alpha_transform_type == "exp": - - def alpha_bar_fn(t): - return math.exp(t * -12.0) - - else: - raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class DPMSolverMultistepInverseSchedulerLegacy(SchedulerMixin, ConfigMixin): - """ - `DPMSolverMultistepInverseSchedulerLegacy` is the reverse scheduler of [`DPMSolverMultistepSchedulerLegacy`]. - - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. - - Args: - num_train_timesteps (`int`, defaults to 1000): - The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): - The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): - The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, *optional*): - Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. - solver_order (`int`, defaults to 2): - The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided - sampling, and `solver_order=3` for unconditional sampling. - prediction_type (`str`, defaults to `epsilon`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://imagen.research.google/video/paper.pdf) paper). - thresholding (`bool`, defaults to `False`): - Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such - as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): - The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): - The threshold value for dynamic thresholding. Valid only when `thresholding=True` and - `algorithm_type="dpmsolver++"`. - algorithm_type (`str`, defaults to `dpmsolver`): - Algorithm type for the solver; can be `dpmsolver` or `sde-dpmsolver`. The - `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) - paper. - solver_type (`str`, defaults to `midpoint`): - Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the - sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. - lower_order_final (`bool`, defaults to `True`): - Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can - stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. - euler_at_final (`bool`, defaults to `False`): - Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail - richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference - steps, but sometimes may result in blurring. - use_karras_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, - the sigmas are determined according to a sequence of noise levels {σi}. - lambda_min_clipped (`float`, defaults to `-inf`): - Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the - cosine (`squaredcos_cap_v2`) noise schedule. - variance_type (`str`, *optional*): - Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output - contains the predicted Gaussian variance. - timestep_spacing (`str`, defaults to `"linspace"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): - An offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable - Diffusion. - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - solver_order: int = 2, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver", - solver_type: str = "midpoint", - lower_order_final: bool = True, - euler_at_final: bool = False, - use_karras_sigmas: Optional[bool] = False, - lambda_min_clipped: float = -float("inf"), - variance_type: Optional[str] = None, - timestep_spacing: str = "linspace", - steps_offset: int = 0, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Currently we only support VP-type noise schedule - self.alpha_t = torch.sqrt(self.alphas_cumprod) - self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) - self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "sde-dpmsolver"]: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") - - if solver_type not in ["midpoint", "heun"]: - if solver_type in ["logrho", "bh1", "bh2"]: - self.register_to_config(solver_type="midpoint") - else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() - self.timesteps = torch.from_numpy(timesteps) - self.model_outputs = [None] * solver_order - self.lower_order_nums = 0 - self._step_index = None - self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - self.use_karras_sigmas = use_karras_sigmas - - @property - def step_index(self): - """ - The index counter for current timestep. It will increae 1 after each scheduler step. - """ - return self._step_index - - def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - - Args: - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - # Clipping the minimum of all lambda(t) for numerical stability. - # This is critical for cosine (squaredcos_cap_v2) noise schedule. - clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped).item() - self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx - - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = ( - np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64) - ) - elif self.config.timestep_spacing == "leading": - step_ratio = (self.noisiest_timestep + 1) // (num_inference_steps + 1) - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[:-1].copy().astype(np.int64) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.arange(self.noisiest_timestep + 1, 0, -step_ratio).round()[::-1].copy().astype(np.int64) - timesteps -= 1 - else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', " - "'leading' or 'trailing'." - ) - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - log_sigmas = np.log(sigmas) - - if self.config.use_karras_sigmas: - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - timesteps = timesteps.copy().astype(np.int64) - sigmas = np.concatenate([sigmas[-1:], sigmas]).astype(np.float32) - else: - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_last = ( - (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] - ) ** 0.5 - sigmas = np.concatenate([[sigma_last], sigmas]).astype(np.float32) - - self.sigmas = torch.from_numpy(sigmas) - - self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) - - self.num_inference_steps = len(timesteps) - - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - - # add an index counter for schedulers that allow duplicated timesteps - self._step_index = None - self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample - def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: - """ - "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the - prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by - s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing - pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights." - - https://arxiv.org/abs/2205.11487 - """ - dtype = sample.dtype - batch_size, channels, *remaining_dims = sample.shape - - if dtype not in (torch.float32, torch.float64): - sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half - - # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - - abs_sample = sample.abs() # "a certain percentile absolute pixel value" - - s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) - s = torch.clamp( - s, min=1, max=self.config.sample_max_value - ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] - s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 - sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" - - sample = sample.reshape(batch_size, channels, *remaining_dims) - sample = sample.to(dtype) - - return sample - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): - # get log sigma - log_sigma = np.log(np.maximum(sigma, 1e-10)) - - # get distribution - dists = log_sigma - log_sigmas[:, np.newaxis] - - # get sigmas range - low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = log_sigmas[low_idx] - high = log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = np.clip(w, 0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.reshape(sigma.shape) - return t - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t - def _sigma_to_alpha_sigma_t(self, sigma): - alpha_t = 1 / ((sigma**2 + 1) ** 0.5) - sigma_t = sigma * alpha_t - - return alpha_t, sigma_t - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: - """Constructs the noise schedule of Karras et al. (2022).""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - rho = 7.0 # 7.0 is the value used in the paper - ramp = np.linspace(0, 1, num_inference_steps) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return sigmas - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.convert_model_output - def convert_model_output( - self, - model_output: torch.FloatTensor, - *args, - sample: torch.FloatTensor = None, - **kwargs, - ) -> torch.FloatTensor: - """ - Convert the model output to predict data. DPM-Solver is designed to discretize an integral of the noise prediction model. - - - - The algorithm and model type are decoupled. You can use either DPMSolver for either noise prediction and data prediction models. - - - - Args: - model_output (`torch.FloatTensor`): - The direct output from the learned diffusion model. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.FloatTensor`: - The converted model output. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - if sample is None: - if len(args) > 1: - sample = args[1] - else: - raise ValueError("missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: - if self.config.prediction_type == "epsilon": - # DPM-Solver and DPM-Solver++ only need the "mean" output. - if self.config.variance_type in ["learned", "learned_range"]: - epsilon = model_output[:, :3] - else: - epsilon = model_output - elif self.config.prediction_type == "sample": - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - epsilon = (sample - alpha_t * model_output) / sigma_t - elif self.config.prediction_type == "v_prediction": - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - epsilon = alpha_t * model_output + sigma_t * sample - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DPMSolverMultistepScheduler." - ) - - if self.config.thresholding: - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - x0_pred = (sample - sigma_t * epsilon) / alpha_t - x0_pred = self._threshold_sample(x0_pred) - epsilon = (sample - alpha_t * x0_pred) / sigma_t - - return epsilon - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.dpm_solver_first_order_update - def dpm_solver_first_order_update( - self, - model_output: torch.FloatTensor, - *args, - sample: torch.FloatTensor = None, - noise: Optional[torch.FloatTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - """ - One step for the first-order DPMSolver (equivalent to DDIM). - - Args: - model_output (`torch.FloatTensor`): - The direct output from the learned diffusion model. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.FloatTensor`: - The sample tensor at the previous timestep. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s = torch.log(alpha_s) - torch.log(sigma_s) - - h = lambda_t - lambda_s - if self.config.algorithm_type == "dpmsolver++": - x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output - elif self.config.algorithm_type == "dpmsolver": - x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output - elif self.config.algorithm_type == "sde-dpmsolver++": - assert noise is not None - x_t = ( - (sigma_t / sigma_s * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - x_t = ( - (alpha_t / alpha_s) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - return x_t - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.multistep_dpm_solver_second_order_update - def multistep_dpm_solver_second_order_update( - self, - model_output_list: List[torch.FloatTensor], - *args, - sample: torch.FloatTensor = None, - noise: Optional[torch.FloatTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - """ - One step for the second-order multistep DPMSolver. - - Args: - model_output_list (`List[torch.FloatTensor]`): - The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.FloatTensor`: - The sample tensor at the previous timestep. - """ - timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing `sample` as a required keyward argument") - if timestep_list is not None: - deprecate( - "timestep_list", - "1.0.0", - "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s0, sigma_s1 = ( - self.sigmas[self.step_index + 1], - self.sigmas[self.step_index], - self.sigmas[self.step_index - 1], - ) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) - - m0, m1 = model_output_list[-1], model_output_list[-2] - - h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 - r0 = h_0 / h - D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2211.01095 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - ) - elif self.config.algorithm_type == "sde-dpmsolver++": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 - + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.solver_type == "heun": - x_t = ( - (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 - + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * (torch.exp(h) - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - return x_t - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.multistep_dpm_solver_third_order_update - def multistep_dpm_solver_third_order_update( - self, - model_output_list: List[torch.FloatTensor], - *args, - sample: torch.FloatTensor = None, - **kwargs, - ) -> torch.FloatTensor: - """ - One step for the third-order multistep DPMSolver. - - Args: - model_output_list (`List[torch.FloatTensor]`): - The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.FloatTensor`): - A current instance of a sample created by diffusion process. - - Returns: - `torch.FloatTensor`: - The sample tensor at the previous timestep. - """ - - timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing`sample` as a required keyward argument") - if timestep_list is not None: - deprecate( - "timestep_list", - "1.0.0", - "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( - self.sigmas[self.step_index + 1], - self.sigmas[self.step_index], - self.sigmas[self.step_index - 1], - self.sigmas[self.step_index - 2], - ) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) - alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) - lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) - - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - - h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 - r0, r1 = h_0 / h, h_1 / h - D0 = m0 - D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) - D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) - D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - ) - return x_t - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - - index_candidates = (self.timesteps == timestep).nonzero() - - if len(index_candidates) == 0: - step_index = len(self.timesteps) - 1 - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() - - self._step_index = step_index - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.step - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - generator=None, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with - the multistep DPMSolver. - - Args: - model_output (`torch.FloatTensor`): - The direct output from learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - generator (`torch.Generator`, *optional*): - A random number generator. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if self.step_index is None: - self._init_step_index(timestep) - - # Improve numerical stability for small number of steps - lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( - self.config.euler_at_final - or (self.config.lower_order_final and len(self.timesteps) < 15) - or self.config.final_sigmas_type == "denoise_to_zero" - ) - lower_order_second = ( - (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 - ) - - model_output = self.convert_model_output(model_output, sample=sample) - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.model_outputs[-1] = model_output - - if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: - noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype - ) - else: - noise = None - - if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) - elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) - else: - prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - # upon completion increase step index by one - self._step_index += 1 - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep_legacy.DPMSolverMultistepSchedulerLegacy.scale_model_input - def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): - The input sample. - - Returns: - `torch.FloatTensor`: - A scaled input sample. - """ - return sample - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - step_indices = [] - for timestep in timesteps: - index_candidates = (schedule_timesteps == timestep).nonzero() - if len(index_candidates) == 0: - step_index = len(schedule_timesteps) - 1 - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() - step_indices.append(step_index) - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - noisy_samples = alpha_t * original_samples + sigma_t * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 1b412f486994..c7784da7fae2 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -271,9 +271,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) if self.config.final_sigmas_type == "default": - sigmas = np.concatenate([sigmas[0], sigmas]).astype(np.float32) + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) elif self.config.final_sigmas_type == "denoise_to_zero": - sigmas = np.concatenate([np.array([0]), sigmas]).astype(np.float32) + sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) else: raise ValueError( f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" @@ -290,9 +290,15 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc raise ValueError( f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" ) - sigmas = np.concatenate([[sigma_last], sigmas]).astype(np.float32) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) self.num_inference_steps = len(timesteps) From 8b16ab643eef81d156a1e38086d6f1bd782b2914 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 19:05:44 +0000 Subject: [PATCH 18/25] Revert "remove dpmsolver and sde-dpmsolver" This reverts commit 15d4fb39adc56dc19ab9013be97a274953681553. --- src/diffusers/__init__.py | 2 - src/diffusers/schedulers/__init__.py | 4 +- .../schedulers/deprecated/__init__.py | 2 - .../scheduling_dpmsolver_multistep_legacy.py | 842 ------------------ .../scheduling_dpmsolver_multistep.py | 104 ++- 5 files changed, 96 insertions(+), 858 deletions(-) delete mode 100644 src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_legacy.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 412a61879334..180b210953c1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -142,7 +142,6 @@ "DEISMultistepScheduler", "DPMSolverMultistepInverseScheduler", "DPMSolverMultistepScheduler", - "DPMSolverMultistepSchedulerLegacy", "DPMSolverSinglestepScheduler", "EulerAncestralDiscreteScheduler", "EulerDiscreteScheduler", @@ -520,7 +519,6 @@ DEISMultistepScheduler, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, - DPMSolverMultistepSchedulerLegacy, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 29e2eedcb6cf..e908ba87acdd 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -38,7 +38,7 @@ _dummy_modules.update(get_objects_from_module(dummy_pt_objects)) else: - _import_structure["deprecated"] = ["DPMSolverMultistepSchedulerLegacy", "KarrasVeScheduler", "ScoreSdeVpScheduler"] + _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] _import_structure["scheduling_amused"] = ["AmusedScheduler"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] @@ -129,7 +129,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: - from .deprecated import DPMSolverMultistepSchedulerLegacy, KarrasVeScheduler, ScoreSdeVpScheduler + from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler from .scheduling_amused import AmusedScheduler from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler diff --git a/src/diffusers/schedulers/deprecated/__init__.py b/src/diffusers/schedulers/deprecated/__init__.py index 2ae7bc1d16a3..786707f45206 100644 --- a/src/diffusers/schedulers/deprecated/__init__.py +++ b/src/diffusers/schedulers/deprecated/__init__.py @@ -21,7 +21,6 @@ _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) else: - _import_structure["scheduling_dpmsolver_multistep_legacy"] = ["DPMSolverMultistepSchedulerLegacy"] _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"] _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"] @@ -33,7 +32,6 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: - from .scheduling_dpmsolver_multistep_legacy import DPMSolverMultistepSchedulerLegacy from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler diff --git a/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_legacy.py b/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_legacy.py deleted file mode 100644 index 8bdf349dab01..000000000000 --- a/src/diffusers/schedulers/deprecated/scheduling_dpmsolver_multistep_legacy.py +++ /dev/null @@ -1,842 +0,0 @@ -# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import deprecate -from ...utils.torch_utils import randn_tensor -from ..scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - if alpha_transform_type == "cosine": - - def alpha_bar_fn(t): - return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 - - elif alpha_transform_type == "exp": - - def alpha_bar_fn(t): - return math.exp(t * -12.0) - - else: - raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class DPMSolverMultistepSchedulerLegacy(SchedulerMixin, ConfigMixin): - """ - `DPMSolverMultistepSchedulerLegacy` is a fast dedicated high-order solver for diffusion ODEs. - - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. - - Args: - num_train_timesteps (`int`, defaults to 1000): - The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): - The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): - The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, *optional*): - Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. - solver_order (`int`, defaults to 2): - The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided - sampling, and `solver_order=3` for unconditional sampling. - prediction_type (`str`, defaults to `epsilon`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://imagen.research.google/video/paper.pdf) paper). - thresholding (`bool`, defaults to `False`): - Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such - as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): - The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): - The threshold value for dynamic thresholding. Valid only when `thresholding=True` and - `algorithm_type="dpmsolver++"`. - algorithm_type (`str`, defaults to `dpmsolver`): - Algorithm type for the solver; can be `dpmsolver`or `sde-dpmsolver`. The - `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) - paper. - solver_type (`str`, defaults to `midpoint`): - Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the - sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. - lower_order_final (`bool`, defaults to `True`): - Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can - stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. - euler_at_final (`bool`, defaults to `False`): - Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail - richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference - steps, but sometimes may result in blurring. - use_karras_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, - the sigmas are determined according to a sequence of noise levels {σi}. - use_lu_lambdas (`bool`, *optional*, defaults to `False`): - Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during - the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of - `lambda(t)`. - lambda_min_clipped (`float`, defaults to `-inf`): - Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the - cosine (`squaredcos_cap_v2`) noise schedule. - variance_type (`str`, *optional*): - Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output - contains the predicted Gaussian variance. - timestep_spacing (`str`, defaults to `"linspace"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): - An offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable - Diffusion. - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - solver_order: int = 2, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver", # "dpmsolver", "sde-dpmsolver" - solver_type: str = "midpoint", - lower_order_final: bool = True, - euler_at_final: bool = False, - use_karras_sigmas: Optional[bool] = False, - use_lu_lambdas: Optional[bool] = False, - lambda_min_clipped: float = -float("inf"), - variance_type: Optional[str] = None, - timestep_spacing: str = "linspace", - steps_offset: int = 0, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Currently we only support VP-type noise schedule - self.alpha_t = torch.sqrt(self.alphas_cumprod) - self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) - self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "sde-dpmsolver"]: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") - - if solver_type not in ["midpoint", "heun"]: - if solver_type in ["logrho", "bh1", "bh2"]: - self.register_to_config(solver_type="midpoint") - else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.model_outputs = [None] * solver_order - self.lower_order_nums = 0 - self._step_index = None - self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - - @property - def step_index(self): - """ - The index counter for current timestep. It will increae 1 after each scheduler step. - """ - return self._step_index - - def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - - Args: - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - # Clipping the minimum of all lambda(t) for numerical stability. - # This is critical for cosine (squaredcos_cap_v2) noise schedule. - clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) - last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() - - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = ( - np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) - ) - elif self.config.timestep_spacing == "leading": - step_ratio = last_timestep // (num_inference_steps + 1) - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) - timesteps -= 1 - else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." - ) - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - log_sigmas = np.log(sigmas) - - if self.config.use_karras_sigmas: - sigmas = np.flip(sigmas).copy() - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) - elif self.config.use_lu_lambdas: - lambdas = np.flip(log_sigmas.copy()) - lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) - sigmas = np.exp(lambdas) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) - else: - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - - self.sigmas = torch.from_numpy(sigmas) - self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) - - self.num_inference_steps = len(timesteps) - - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - - # add an index counter for schedulers that allow duplicated timesteps - self._step_index = None - self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample - def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: - """ - "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the - prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by - s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing - pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights." - - https://arxiv.org/abs/2205.11487 - """ - dtype = sample.dtype - batch_size, channels, *remaining_dims = sample.shape - - if dtype not in (torch.float32, torch.float64): - sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half - - # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - - abs_sample = sample.abs() # "a certain percentile absolute pixel value" - - s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) - s = torch.clamp( - s, min=1, max=self.config.sample_max_value - ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] - s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 - sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" - - sample = sample.reshape(batch_size, channels, *remaining_dims) - sample = sample.to(dtype) - - return sample - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): - # get log sigma - log_sigma = np.log(np.maximum(sigma, 1e-10)) - - # get distribution - dists = log_sigma - log_sigmas[:, np.newaxis] - - # get sigmas range - low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = log_sigmas[low_idx] - high = log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = np.clip(w, 0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.reshape(sigma.shape) - return t - - def _sigma_to_alpha_sigma_t(self, sigma): - alpha_t = 1 / ((sigma**2 + 1) ** 0.5) - sigma_t = sigma * alpha_t - - return alpha_t, sigma_t - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: - """Constructs the noise schedule of Karras et al. (2022).""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - rho = 7.0 # 7.0 is the value used in the paper - ramp = np.linspace(0, 1, num_inference_steps) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return sigmas - - def _convert_to_lu(self, in_lambdas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: - """Constructs the noise schedule of Lu et al. (2022).""" - - lambda_min: float = in_lambdas[-1].item() - lambda_max: float = in_lambdas[0].item() - - rho = 1.0 # 1.0 is the value used in the paper - ramp = np.linspace(0, 1, num_inference_steps) - min_inv_rho = lambda_min ** (1 / rho) - max_inv_rho = lambda_max ** (1 / rho) - lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return lambdas - - def convert_model_output( - self, - model_output: torch.FloatTensor, - *args, - sample: torch.FloatTensor = None, - **kwargs, - ) -> torch.FloatTensor: - """ - Convert the model output. DPM-Solver is designed to discretize an integral of the noise prediction model. - - - - The algorithm and model type are decoupled. You can use either DPMSolver for both noise - prediction and data prediction models. - - - - Args: - model_output (`torch.FloatTensor`): - The direct output from the learned diffusion model. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.FloatTensor`: - The converted model output. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - if sample is None: - if len(args) > 1: - sample = args[1] - else: - raise ValueError("missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - # DPM-Solver needs to solve an integral of the noise prediction model. - if self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: - if self.config.prediction_type == "epsilon": - # DPM-Solver only need the "mean" output. - if self.config.variance_type in ["learned", "learned_range"]: - epsilon = model_output[:, :3] - else: - epsilon = model_output - elif self.config.prediction_type == "sample": - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - epsilon = (sample - alpha_t * model_output) / sigma_t - elif self.config.prediction_type == "v_prediction": - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - epsilon = alpha_t * model_output + sigma_t * sample - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DPMSolverMultistepScheduler." - ) - - if self.config.thresholding: - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - x0_pred = (sample - sigma_t * epsilon) / alpha_t - x0_pred = self._threshold_sample(x0_pred) - epsilon = (sample - alpha_t * x0_pred) / sigma_t - - return epsilon - - def dpm_solver_first_order_update( - self, - model_output: torch.FloatTensor, - *args, - sample: torch.FloatTensor = None, - noise: Optional[torch.FloatTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - """ - One step for the first-order DPMSolver (equivalent to DDIM). - - Args: - model_output (`torch.FloatTensor`): - The direct output from the learned diffusion model. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.FloatTensor`: - The sample tensor at the previous timestep. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s = torch.log(alpha_s) - torch.log(sigma_s) - - h = lambda_t - lambda_s - if self.config.algorithm_type == "dpmsolver": - x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - x_t = ( - (alpha_t / alpha_s) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - return x_t - - def multistep_dpm_solver_second_order_update( - self, - model_output_list: List[torch.FloatTensor], - *args, - sample: torch.FloatTensor = None, - noise: Optional[torch.FloatTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - """ - One step for the second-order multistep DPMSolver. - - Args: - model_output_list (`List[torch.FloatTensor]`): - The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.FloatTensor`: - The sample tensor at the previous timestep. - """ - timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing `sample` as a required keyward argument") - if timestep_list is not None: - deprecate( - "timestep_list", - "1.0.0", - "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s0, sigma_s1 = ( - self.sigmas[self.step_index + 1], - self.sigmas[self.step_index], - self.sigmas[self.step_index - 1], - ) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) - - m0, m1 = model_output_list[-1], model_output_list[-2] - - h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 - r0 = h_0 / h - D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - ) - if self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * (torch.exp(h) - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - return x_t - - def multistep_dpm_solver_third_order_update( - self, - model_output_list: List[torch.FloatTensor], - *args, - sample: torch.FloatTensor = None, - **kwargs, - ) -> torch.FloatTensor: - """ - One step for the third-order multistep DPMSolver. - - Args: - model_output_list (`List[torch.FloatTensor]`): - The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.FloatTensor`): - A current instance of a sample created by diffusion process. - - Returns: - `torch.FloatTensor`: - The sample tensor at the previous timestep. - """ - - timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing`sample` as a required keyward argument") - if timestep_list is not None: - deprecate( - "timestep_list", - "1.0.0", - "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( - self.sigmas[self.step_index + 1], - self.sigmas[self.step_index], - self.sigmas[self.step_index - 1], - self.sigmas[self.step_index - 2], - ) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) - alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) - lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) - - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - - h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 - r0, r1 = h_0 / h, h_1 / h - D0 = m0 - D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) - D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) - D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - - if self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - ) - return x_t - - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - - index_candidates = (self.timesteps == timestep).nonzero() - - if len(index_candidates) == 0: - step_index = len(self.timesteps) - 1 - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() - - self._step_index = step_index - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - generator=None, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with - the multistep DPMSolver. - - Args: - model_output (`torch.FloatTensor`): - The direct output from learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - generator (`torch.Generator`, *optional*): - A random number generator. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if self.step_index is None: - self._init_step_index(timestep) - - # Improve numerical stability for small number of steps - lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( - self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) - ) - lower_order_second = ( - (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 - ) - - model_output = self.convert_model_output(model_output, sample=sample) - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.model_outputs[-1] = model_output - - if self.config.algorithm_type in ["sde-dpmsolver"]: - noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype - ) - else: - noise = None - - if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) - elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) - else: - prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - # upon completion increase step index by one - self._step_index += 1 - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): - The input sample. - - Returns: - `torch.FloatTensor`: - A scaled input sample. - """ - return sample - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - step_indices = [] - for timestep in timesteps: - index_candidates = (schedule_timesteps == timestep).nonzero() - if len(index_candidates) == 0: - step_index = len(schedule_timesteps) - 1 - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() - step_indices.append(step_index) - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - noisy_samples = alpha_t * original_samples + sigma_t * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index c2c6bc90969e..ddbd550b2b6e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -106,7 +106,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. It implements the algorithms in the + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): @@ -194,13 +196,9 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]: + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") - elif algorithm_type in ["dpmsolver", "sde-dpmsolver"]: - raise ValueError( - f"`algorithm_type` {algorithm_type} is no longer supported in {self.__class__}. Please use `DPMSolverSchedulerLegacy` instead." - ) else: raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") @@ -210,6 +208,11 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithem_type` {algorithm_type}." + ) + # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() @@ -429,11 +432,14 @@ def convert_model_output( **kwargs, ) -> torch.FloatTensor: """ - Convert the model output to predict data. DPM-Solver++ is designed to discretize an integral of the data prediction model. + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. - The algorithm and model type are decoupled. You can use DPMSolver++ for either noise prediction and data prediction models. + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. @@ -463,7 +469,7 @@ def convert_model_output( # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": - # DPM-Solver++ only need the "mean" output. + # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] sigma = self.sigmas[self.step_index] @@ -486,6 +492,37 @@ def convert_model_output( return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, @@ -537,6 +574,8 @@ def dpm_solver_first_order_update( h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None x_t = ( @@ -544,6 +583,13 @@ def dpm_solver_first_order_update( + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) return x_t def multistep_dpm_solver_second_order_update( @@ -621,6 +667,20 @@ def multistep_dpm_solver_second_order_update( - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None if self.config.solver_type == "midpoint": @@ -637,6 +697,22 @@ def multistep_dpm_solver_second_order_update( + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) return x_t def multistep_dpm_solver_third_order_update( @@ -714,6 +790,14 @@ def multistep_dpm_solver_third_order_update( + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) return x_t def _init_step_index(self, timestep): @@ -788,7 +872,7 @@ def step( self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output - if self.config.algorithm_type in ["sde-dpmsolver++"]: + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype ) From 41e25e99067280ef3cc536fb24492c4cfbfc1d7e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 19:25:51 +0000 Subject: [PATCH 19/25] fix --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index ddbd550b2b6e..620d0a2d1d52 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -210,7 +210,7 @@ def __init__( if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero": raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithem_type` {algorithm_type}." + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}." ) # setable values From 085498d0501b249b0be1cde88788bc346d6f1caf Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 19:54:04 +0000 Subject: [PATCH 20/25] dpm inverse --- .../scheduling_dpmsolver_multistep_inverse.py | 39 +++++-------------- 1 file changed, 10 insertions(+), 29 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index c7784da7fae2..c8e51af702ff 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -160,12 +160,15 @@ def __init__( lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, - final_sigmas_type: Optional[str] = "default", # "denoise_to_zero", "default" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message, warning=True) + if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -203,11 +206,6 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero": - raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}." - ) - # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() @@ -270,27 +268,13 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) - if self.config.final_sigmas_type == "default": - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) - elif self.config.final_sigmas_type == "denoise_to_zero": - sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" - ) + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - if self.config.final_sigmas_type == "default": - sigma_last = ( - (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] - ) ** 0.5 - elif self.config.final_sigmas_type == "denoise_to_zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" - ) - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + sigma_max = ( + (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] + ) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) @@ -803,7 +787,6 @@ def _init_step_index(self, timestep): self._step_index = step_index - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step def step( self, model_output: torch.FloatTensor, @@ -844,9 +827,7 @@ def step( # Improve numerical stability for small number of steps lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( - self.config.euler_at_final - or (self.config.lower_order_final and len(self.timesteps) < 15) - or self.config.final_sigmas_type == "denoise_to_zero" + self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 From 4666722221673122da4985862f45a53e30af747a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 20:33:15 +0000 Subject: [PATCH 21/25] fix --- .../scheduling_dpmsolver_multistep.py | 50 ++++++++----------- .../scheduling_dpmsolver_singlestep.py | 46 ++++++++--------- 2 files changed, 44 insertions(+), 52 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 620d0a2d1d52..a6d6a4b642e6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -128,6 +128,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of `lambda(t)`. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma + is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. @@ -164,13 +167,17 @@ def __init__( lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, - final_sigmas_type: Optional[str] = "default", # "denoise_to_zero", "default" use_lu_lambdas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message, warning=True) + if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -208,9 +215,9 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero": + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}." + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." ) # setable values @@ -273,39 +280,24 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - if self.config.final_sigmas_type == "default": - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) - elif self.config.final_sigmas_type == "denoise_to_zero": - sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" - ) elif self.config.use_lu_lambdas: lambdas = np.flip(log_sigmas.copy()) lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) sigmas = np.exp(lambdas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - if self.config.final_sigmas_type == "default": - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) - elif self.config.final_sigmas_type == "denoise_to_zero": - sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32) - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" - ) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - if self.config.final_sigmas_type == "default": - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - elif self.config.final_sigmas_type == "denoise_to_zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}" - ) - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) @@ -861,7 +853,7 @@ def step( lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) - or self.config.final_sigmas_type == "denoise_to_zero" + or self.config.final_sigmas_type == "zero" ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index fb8fd80930b6..84b04cb644e2 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -108,7 +108,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or @@ -122,6 +122,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + final_sigmas_type (`str`, *optional*, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma + is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. @@ -150,10 +153,14 @@ def __init__( solver_type: str = "midpoint", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, - final_sigmas_type: Optional[str] = "default", # "denoise_to_zero", "default" + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): + if algorithm_type == "dpmsolver": + deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message, warning=True) + if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -190,9 +197,9 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - if algorithm_type != "dpmsolver++" and final_sigmas_type == "denoise_to_zero": + if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero": raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}." + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead." ) # setable values @@ -273,25 +280,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - if self.config.final_sigmas_type == "default": - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) - elif self.config.final_sigmas_type == "denoise_to_zero": - sigmas = np.concatenate([sigmas, np.array([0.0])]).astype(np.float32) - else: - raise ValueError( - f" `final_sigmas_type` must be one of `default` or `denoise_to_zero`, but got {self.config.final_sigmas_type}" - ) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - if self.config.final_sigmas_type == "default": - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - elif self.config.final_sigmas_type == "denoise_to_zero": - sigma_last = 0 - else: - raise ValueError( - f" `final_sigmas_type` must be one of `default` or `denoise_to_zero`, but got {self.config.final_sigmas_type}" - ) - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) @@ -305,9 +305,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) self.register_to_config(lower_order_final=True) - if not self.config.lower_order_final and self.config.final_sigmas_type == "denoise_to_zero": + if not self.config.lower_order_final and self.config.final_sigmas_type == "zero": logger.warn( - "denoise_to_zero is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True." + " `last_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True." ) self.register_to_config(lower_order_final=True) From b68106d910cb81324be4791248cdb002e51291a8 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 20:59:36 +0000 Subject: [PATCH 22/25] fix deprecate --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- .../schedulers/scheduling_dpmsolver_multistep_inverse.py | 2 +- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index a6d6a4b642e6..44ac6f2d0cf9 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -176,7 +176,7 @@ def __init__( ): if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" - deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message, warning=True) + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index c8e51af702ff..aea3c4c4b9f6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -167,7 +167,7 @@ def __init__( ): if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" - deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message, warning=True) + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 84b04cb644e2..b92faf359920 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -159,7 +159,7 @@ def __init__( ): if algorithm_type == "dpmsolver": deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" - deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message, warning=True) + deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) From 91c02fded021ed6cd8d2533b6d05d66dac0433d1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 21:09:01 +0000 Subject: [PATCH 23/25] fix tests --- tests/schedulers/test_scheduler_dpm_multi.py | 1 + tests/schedulers/test_scheduler_dpm_single.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index 7fe71941b4e7..515cea5bc4ba 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -32,6 +32,7 @@ def get_scheduler_config(self, **kwargs): "euler_at_final": False, "lambda_min_clipped": -float("inf"), "variance_type": None, + "final_sigmas_type": "sigma_min", } config.update(**kwargs) diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index b1f14cb530e6..26c7e5329ba1 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -30,6 +30,7 @@ def get_scheduler_config(self, **kwargs): "solver_type": "midpoint", "lambda_min_clipped": -float("inf"), "variance_type": None, + "final_sigmas_type": "sigma_min" } config.update(**kwargs) From 643cdc1d0a4ba7ff4b6bc6022f98c9bc1670f500 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 21:15:39 +0000 Subject: [PATCH 24/25] style --- tests/schedulers/test_scheduler_dpm_single.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 26c7e5329ba1..251a1506f39b 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -30,7 +30,7 @@ def get_scheduler_config(self, **kwargs): "solver_type": "midpoint", "lambda_min_clipped": -float("inf"), "variance_type": None, - "final_sigmas_type": "sigma_min" + "final_sigmas_type": "sigma_min", } config.update(**kwargs) From 64e35fd0e6acf3f8c5b6c5a19116b596612a1a77 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 18 Jan 2024 22:59:01 +0000 Subject: [PATCH 25/25] fix format --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 44ac6f2d0cf9..8b60c02929d6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -130,7 +130,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): `lambda(t)`. final_sigmas_type (`str`, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma - is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index b92faf359920..f4ee8b8bdc5d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -124,7 +124,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. final_sigmas_type (`str`, *optional*, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma - is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule.