From f90595882236f688c29cc58031dec8374d6de03c Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 3 Jul 2023 15:55:16 +0200
Subject: [PATCH 01/18] Add timestep_spacing to DDPM, LMSDiscrete, PNDM.

---
 src/diffusers/schedulers/scheduling_ddim.py   |  2 +-
 src/diffusers/schedulers/scheduling_ddpm.py   | 35 ++++++++++++++--
 .../schedulers/scheduling_lms_discrete.py     | 40 +++++++++++++++++--
 src/diffusers/schedulers/scheduling_pndm.py   | 29 +++++++++++---
 4 files changed, 93 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index bab6f8acea03..3ffa58c55f2d 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -302,7 +302,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         self.num_inference_steps = num_inference_steps
 
-        # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
+        # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 5d24766d68c7..43d252f1c3e8 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -114,6 +114,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
         sample_max_value (`float`, default `1.0`):
             the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+        timestep_spacing (`str`, default `"leading"`):
+            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -134,6 +141,8 @@ def __init__(
         dynamic_thresholding_ratio: float = 0.995,
         clip_sample_range: float = 1.0,
         sample_max_value: float = 1.0,
+        timestep_spacing: str = "leading",
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -228,11 +237,31 @@ def set_timesteps(
             )
 
         self.num_inference_steps = num_inference_steps
-
-        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
         self.custom_timesteps = False
 
+        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round()[ + ::-1 + ].copy().astype(np.int64) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) def _get_variance(self, t, predicted_variance=None, variance_type=None): diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 0656475c3093..cd323f0de140 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -102,6 +102,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. 
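+
+    As a concrete example of the three spacings implemented in `set_timesteps` below: with
+    `num_train_timesteps=1000` and `num_inference_steps=10`, `"linspace"` produces the timesteps
+    `[999, 888, ..., 111, 0]`, `"leading"` produces `[900, 800, ..., 100, 0]` (shifted by
+    `steps_offset`), and `"trailing"` produces `[999, 899, ..., 199, 99]`.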
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -117,6 +124,8 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -140,9 +149,6 @@ def __init__( sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - # setable values self.num_inference_steps = None self.use_karras_sigmas = use_karras_sigmas @@ -150,6 +156,14 @@ def __init__( self.derivatives = [] self.is_scale_input_called = False + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing == "linspace": + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -207,6 +221,26 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 01c02a21bbfc..7ce348c7e2d0 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -85,11 +85,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. 
steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -106,6 +108,7 @@ def __init__( skip_prk_steps: bool = False, set_alpha_to_one: bool = False, prediction_type: str = "epsilon", + timestep_spacing: str = "leading", steps_offset: int = 0, ): if trained_betas is not None: @@ -159,11 +162,25 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() - self._timesteps += self.config.steps_offset + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + self._timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() + self._timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype(np.int64) + self._timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to From d9b68b974818c68f2f9ceebfee28235a18da161d Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 3 Jul 2023 16:02:23 +0200 Subject: [PATCH 02/18] Remove spurious line. --- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index cd323f0de140..c8fda1df8021 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -219,8 +219,6 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() From eedf428eb51cd6a8c1b9bef94e76128d57863699 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 3 Jul 2023 16:27:10 +0200 Subject: [PATCH 03/18] More easy schedulers. --- .../scheduling_euler_ancestral_discrete.py | 41 +++++++++++++++++-- .../schedulers/scheduling_euler_discrete.py | 16 ++++++-- .../schedulers/scheduling_heun_discrete.py | 40 ++++++++++++++++-- .../scheduling_k_dpm_2_ancestral_discrete.py | 40 ++++++++++++++++-- .../schedulers/scheduling_k_dpm_2_discrete.py | 40 ++++++++++++++++-- 5 files changed, 157 insertions(+), 20 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 6b08e9bfc207..5ec6118c6405 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -99,6 +99,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ @@ -114,6 +121,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -137,15 +146,20 @@ def __init__( sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.is_scale_input_called = False + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing == "linspace": + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -179,7 +193,26 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 1a55a1b0e0d8..e8825a2cbdb8 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -205,17 +205,25 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - # "linspace" and "leading" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ - ::-1 - ].copy() + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 100e2012ea20..e734ef4b91e8 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -78,6 +78,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -93,6 +100,8 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -128,6 +137,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing == "linspace": + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -166,7 +183,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) @@ -180,9 +215,6 @@ def set_timesteps( sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - timesteps = torch.from_numpy(timesteps) timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 2fa0431e1292..f7591352d193 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -78,6 +78,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -92,6 +99,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -127,6 +136,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing == "linspace": + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -169,7 +186,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) @@ -197,9 +232,6 @@ def set_timesteps( self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]]) self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - if str(device).startswith("mps"): # mps does not support float64 timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index bb80c4a54bfe..0b7d9c7967ab 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -77,6 +77,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. 
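+
+    A minimal usage sketch (illustrative only; assumes `pipe` is an existing pipeline whose
+    scheduler config is being reused):
+
+    ```py
+    >>> from diffusers import KDPM2DiscreteScheduler
+    >>> pipe.scheduler = KDPM2DiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+    ```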
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -91,6 +98,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -126,6 +135,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing == "linspace": + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -168,7 +185,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) @@ -185,9 +220,6 @@ def set_timesteps( [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] ) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - if str(device).startswith("mps"): # mps does not support float64 timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) From 9f0db7bff90402b671972b12ddb2cb2963846d3c Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 3 Jul 2023 16:30:13 +0200 Subject: [PATCH 04/18] Add `linspace` to DDIM --- src/diffusers/schedulers/scheduling_ddim.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 3ffa58c55f2d..df5b72bde3c9 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -302,8 +302,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.num_inference_steps = num_inference_steps - # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "leading": + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round()[ + ::-1 + ].copy().astype(np.int64) + elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 From 2bde0780f7ece8625cd79fd0bdb490523bb81e23 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 3 Jul 2023 17:36:37 +0200 Subject: [PATCH 05/18] Noise sigma for `trailing`. --- src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py | 2 +- src/diffusers/schedulers/scheduling_euler_discrete.py | 2 +- src/diffusers/schedulers/scheduling_heun_discrete.py | 2 +- .../schedulers/scheduling_k_dpm_2_ancestral_discrete.py | 2 +- src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py | 2 +- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 5ec6118c6405..01c22dc428f5 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -155,7 +155,7 @@ def __init__( @property def init_noise_sigma(self): # standard deviation of the initial noise distribution - if self.config.timestep_spacing == "linspace": + if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index e8825a2cbdb8..654a8f5999ad 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -165,7 +165,7 @@ def __init__( @property def init_noise_sigma(self): # standard deviation of the initial noise distribution - if self.config.timestep_spacing == "linspace": + if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index e734ef4b91e8..28f29067a544 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -140,7 +140,7 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): @property def init_noise_sigma(self): # standard deviation of the initial noise distribution - if self.config.timestep_spacing == "linspace": + if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index f7591352d193..d4a35ab82502 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -139,7 +139,7 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): @property def init_noise_sigma(self): # standard deviation of the initial noise distribution - if self.config.timestep_spacing == "linspace": + if self.config.timestep_spacing in ["linspace", "trailing"]: return 
self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 0b7d9c7967ab..39079fde10d2 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -138,7 +138,7 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): @property def init_noise_sigma(self): # standard deviation of the initial noise distribution - if self.config.timestep_spacing == "linspace": + if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index c8fda1df8021..cfa3fe09d273 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -159,7 +159,7 @@ def __init__( @property def init_noise_sigma(self): # standard deviation of the initial noise distribution - if self.config.timestep_spacing == "linspace": + if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 From 2ad9bb18da117c407ae316778a19feecef58f284 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 3 Jul 2023 18:48:09 +0200 Subject: [PATCH 06/18] Add timestep_spacing to DEISMultistepScheduler. Not sure the range is the way it was intended. --- .../schedulers/scheduling_deis_multistep.py | 51 ++++++++++++++++--- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 8ea001a882d0..88e0a0d1eef7 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -103,7 +103,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10. - + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -125,6 +131,8 @@ def __init__( algorithm_type: str = "deis", solver_type: str = "logrho", lower_order_final: bool = True, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -181,12 +189,41 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
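+
+            Note: unlike the single-step schedulers, the `"linspace"` and `"leading"` branches
+            below sample `num_inference_steps + 1` points and drop the final one, so the returned
+            schedule does not end at timestep `0`.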
""" - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + self.config.timestep_spacing = "trailing" + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + np.arange(self.config.num_train_timesteps, 0, -step_ratio) + .round() + .copy() + .astype(np.int64) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. From fcc5531b06e89a8f69ad8a07f82d589672a00deb Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 3 Jul 2023 20:06:17 +0200 Subject: [PATCH 07/18] Fix: remove line used to debug. --- src/diffusers/schedulers/scheduling_deis_multistep.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 88e0a0d1eef7..90898b94bc6b 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -190,7 +190,6 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 - self.config.timestep_spacing = "trailing" if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) From f6e47829a21a53a2fe3a00ad88eb7102e78824c6 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 3 Jul 2023 20:19:32 +0200 Subject: [PATCH 08/18] Support timestep_spacing in DPMSolverMultistep, DPMSolverSDE, UniPC --- .../scheduling_dpmsolver_multistep.py | 42 ++++++++++++-- .../schedulers/scheduling_dpmsolver_sde.py | 40 +++++++++++-- .../schedulers/scheduling_unipc_multistep.py | 56 +++++++++++++++++-- 3 files changed, 122 insertions(+), 16 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index e72b1bdc23b5..91e90c8ac265 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -218,12 +218,42 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # Clipping the minimum of all lambda(t) for numerical stability. # This is critical for cosine (squaredcos_cap_v2) noise schedule. clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + last_timestep = self.config.num_train_timesteps - clipped_idx + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + np.arange(last_timestep, 0, -step_ratio) + .round() + .copy() + .astype(np.int64) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) if self.use_karras_sigmas: sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index ae9229981152..da8b71788b75 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -133,6 +133,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. noise_sampler_seed (`int`, *optional*, defaults to `None`): The random seed to use for the noise sampler. If `None`, a random seed will be generated. 
+ timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -149,6 +156,8 @@ def __init__( prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -187,6 +196,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -226,7 +243,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+            )
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         log_sigmas = np.log(sigmas)
@@ -242,9 +277,6 @@ def set_timesteps(
         sigmas = torch.from_numpy(sigmas).to(device=device)
         self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
 
-        # standard deviation of the initial noise distribution
-        self.init_noise_sigma = self.sigmas.max()
-
         timesteps = torch.from_numpy(timesteps)
         second_order_timesteps = torch.from_numpy(second_order_timesteps)
         timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 2cce68f7d962..451929edfc09 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -117,6 +117,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             by disable the corrector at the first few steps (e.g., disable_corrector=[0])
         solver_p (`SchedulerMixin`, default `None`):
             can be any other scheduler. If specified, the algorithm will become solver_p + UniC.
+        timestep_spacing (`str`, default `"linspace"`):
+            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -140,6 +147,8 @@ def __init__(
         lower_order_final: bool = True,
         disable_corrector: List[int] = [],
         solver_p: SchedulerMixin = None,
+        timestep_spacing: str = "linspace",
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -194,12 +203,40 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-        timesteps = (
-            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
-            .round()[::-1][:-1]
-            .copy()
-            .astype(np.int64)
-        )
+        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + np.arange(self.config.num_train_timesteps, 0, -step_ratio) + .round() + .copy() + .astype(np.int64) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. From 5cede2b5631d646d1ed51512a7bd6dd334d89d47 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 4 Jul 2023 00:19:04 +0200 Subject: [PATCH 09/18] Fix: convert to numpy. --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 91e90c8ac265..ba5340a809e1 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -218,7 +218,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # Clipping the minimum of all lambda(t) for numerical stability. # This is critical for cosine (squaredcos_cap_v2) noise schedule. clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) - last_timestep = self.config.num_train_timesteps - clipped_idx + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": From 04b9a76c7f152da4b2001037d686885b5869e200 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 4 Jul 2023 01:14:10 +0200 Subject: [PATCH 10/18] Use sched. defaults when instantiating from_config For params not present in the original configuration. This makes it possible to switch pipeline schedulers even if they use different timestep_spacing (or any other param). 
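
A sketch of the behavior this enables (the pipeline id is the tiny model used in the
test suite; the assertions mirror the test added later in this series):

    from diffusers import DiffusionPipeline, EulerDiscreteScheduler

    pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
    # The saved scheduler config predates `timestep_spacing`, so the key is absent.
    # Switching schedulers now falls back to the new class's own default
    # ("linspace" for Euler) instead of inheriting a value that was never stored.
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    assert pipe.scheduler.config.timestep_spacing == "linspace"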
--- src/diffusers/configuration_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 1a030e467134..59ade8da4404 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -423,6 +423,10 @@ def _get_init_keys(cls): @classmethod def extract_init_dict(cls, config_dict, **kwargs): + # Skip keys that were not present in the original config, so default __init__ values were used + used_defaults = config_dict.get("_use_default_values", []) + config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"} + # 0. Copy origin config dict original_dict = dict(config_dict.items()) @@ -450,6 +454,10 @@ def extract_init_dict(cls, config_dict, **kwargs): else: compatible_classes = [] + # Keys not present in the config - default values were used + # TODO: remove the ones passed in kwargs? + used_default_keys = set(expected_keys) - set(original_dict.keys()) + expected_keys_comp_cls = set() for c in compatible_classes: expected_keys_c = cls._get_init_keys(c) @@ -503,6 +511,9 @@ def extract_init_dict(cls, config_dict, **kwargs): # 7. Define "hidden" config parameters that were saved for compatible classes hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} + # 8. Take note of the parameters that were not present in the loaded config + hidden_config_dict["_use_default_values"] = used_default_keys + return init_dict, unused_kwargs, hidden_config_dict @classmethod From e9dc9fdf253c46aba421cc02c03381ad0ed6a6c6 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 4 Jul 2023 12:02:03 +0200 Subject: [PATCH 11/18] Apply suggestions from code review Co-authored-by: Patrick von Platen --- src/diffusers/configuration_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 59ade8da4404..81509384e79a 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -454,10 +454,6 @@ def extract_init_dict(cls, config_dict, **kwargs): else: compatible_classes = [] - # Keys not present in the config - default values were used - # TODO: remove the ones passed in kwargs? - used_default_keys = set(expected_keys) - set(original_dict.keys()) - expected_keys_comp_cls = set() for c in compatible_classes: expected_keys_c = cls._get_init_keys(c) @@ -512,7 +508,7 @@ def extract_init_dict(cls, config_dict, **kwargs): hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} # 8. 
Take note of the parameters that were not present in the loaded config - hidden_config_dict["_use_default_values"] = used_default_keys + hidden_config_dict["_use_default_values"] = expected_keys - set(init_dict) return init_dict, unused_kwargs, hidden_config_dict From a30f92a77888d2153a9487c9bf6f7f8282268ecf Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 4 Jul 2023 12:27:30 +0200 Subject: [PATCH 12/18] Missing args in DPMSolverMultistep --- .../schedulers/scheduling_dpmsolver_multistep.py | 9 +++++++++ .../schedulers/scheduling_euler_ancestral_discrete.py | 1 - 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index ba5340a809e1..b96c177ebe5e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -134,6 +134,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -158,6 +165,8 @@ def __init__( use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 01c22dc428f5..132729b7eb88 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -106,7 +106,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. 
-
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]

From a9ccadea710a88ed6ad6966b3d421dcfdacec638 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 4 Jul 2023 12:28:19 +0200
Subject: [PATCH 13/18] Test: default args not in config

---
 tests/schedulers/test_schedulers.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py
index a2d065f388bd..b2df100b5d3e 100755
--- a/tests/schedulers/test_schedulers.py
+++ b/tests/schedulers/test_schedulers.py
@@ -24,6 +24,8 @@
 import diffusers
 from diffusers import (
+    DDIMScheduler,
+    DiffusionPipeline,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     IPNDMScheduler,
@@ -202,6 +204,31 @@ def test_save_load_from_different_config_comp_schedulers(self):
         assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
         assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
 
+    def test_default_arguments_not_in_config(self):
+        pipe = DiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16
+        )
+        assert pipe.scheduler.__class__ == DDIMScheduler
+
+        # Default for PNDMScheduler
+        assert pipe.scheduler.config.timestep_spacing == "leading"
+
+        # Switch to a different one, verify we use the default for that class
+        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+        assert pipe.scheduler.config.timestep_spacing == "linspace"
+
+        # Override with kwargs
+        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+        assert pipe.scheduler.config.timestep_spacing == "trailing"
+
+        # Verify overridden kwargs stick
+        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
+        assert pipe.scheduler.config.timestep_spacing == "trailing"
+
+        # And stick
+        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
+        assert pipe.scheduler.config.timestep_spacing == "trailing"
+
 
 class SchedulerCommonTest(unittest.TestCase):
     scheduler_classes = ()
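
For concreteness, the three spacing strategies exercised by this test produce different timestep grids. A quick standalone comparison (illustrative only; 1000 training timesteps, 10 inference steps, ignoring `steps_offset`, mirroring the formulas in the scheduler patches):

import numpy as np

num_train, n = 1000, 10
linspace = np.linspace(0, num_train - 1, n).round()[::-1].astype(np.int64)
leading = (np.arange(0, n) * (num_train // n)).round()[::-1].astype(np.int64)
trailing = np.round(np.arange(num_train, 0, -num_train / n)).astype(np.int64) - 1

print(linspace)  # [999 888 777 666 555 444 333 222 111   0]
print(leading)   # [900 800 700 600 500 400 300 200 100   0]
print(trailing)  # [999 899 799 699 599 499 399 299 199  99]

Note that "leading" never samples the final training timestep (999), while "trailing" starts there at the cost of skipping 0; this mismatch is what the referenced paper (https://arxiv.org/abs/2305.08891) analyzes.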

From 1ac79fe835650d4126130048d08a9c65445841b9 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 4 Jul 2023 12:29:26 +0200
Subject: [PATCH 14/18] Style

---
 src/diffusers/configuration_utils.py          |  2 +-
 src/diffusers/schedulers/scheduling_ddim.py   |  9 ++++++---
 src/diffusers/schedulers/scheduling_ddpm.py   | 10 ++++++----
 .../schedulers/scheduling_deis_multistep.py   | 14 ++------------
 .../scheduling_dpmsolver_multistep.py         | 19 +++----------------
 .../scheduling_euler_ancestral_discrete.py    |  4 +++-
 .../schedulers/scheduling_euler_discrete.py   |  4 +++-
 .../schedulers/scheduling_lms_discrete.py     |  4 +++-
 src/diffusers/schedulers/scheduling_pndm.py   |  8 ++++++--
 .../schedulers/scheduling_unipc_multistep.py  | 14 ++------------
 10 files changed, 35 insertions(+), 53 deletions(-)

diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 81509384e79a..33bc32f6b743 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -508,7 +508,7 @@ def extract_init_dict(cls, config_dict, **kwargs):
         hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
 
         # 8. Take note of the parameters that were not present in the loaded config
-        hidden_config_dict["_use_default_values"] = expected_keys - set(init_dict)
+        hidden_config_dict["_use_default_values"] = expected_keys - set(init_dict)
 
         return init_dict, unused_kwargs, hidden_config_dict
 
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index df5b72bde3c9..99602d14038b 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -304,9 +304,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "linspace":
-            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round()[
-                ::-1
-            ].copy().astype(np.int64)
+            timesteps = (
+                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+                .round()[::-1]
+                .copy()
+                .astype(np.int64)
+            )
         elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 43d252f1c3e8..930e14108a67 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -241,9 +241,12 @@
 
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "linspace":
-            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round()[
-                ::-1
-            ].copy().astype(np.int64)
+            timesteps = (
+                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+                .round()[::-1]
+                .copy()
+                .astype(np.int64)
+            )
         elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
@@ -261,7 +264,6 @@
                 f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
             )
-
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index 90898b94bc6b..95fb9870d883 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -201,23 +201,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (
-                (np.arange(0, num_inference_steps + 1) * step_ratio)
-                .round()[::-1][:-1]
-                .copy()
-                .astype(np.int64)
-            )
+            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
             timesteps += self.config.steps_offset
         elif self.config.timestep_spacing == "trailing":
             step_ratio = self.config.num_train_timesteps / num_inference_steps
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (
-                np.arange(self.config.num_train_timesteps, 0, -step_ratio)
-                .round()
-                .copy()
-                .astype(np.int64)
-            )
+            timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
             timesteps -= 1
         else:
             raise ValueError(
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index b96c177ebe5e..8c09fe5b8ca2 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -232,32 +232,19 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
 
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "linspace":
             timesteps = (
-                np.linspace(0, last_timestep - 1, num_inference_steps + 1)
-                .round()[::-1][:-1]
-                .copy()
-                .astype(np.int64)
+                np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
             )
         elif self.config.timestep_spacing == "leading":
             step_ratio = last_timestep // (num_inference_steps + 1)
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (
-                (np.arange(0, num_inference_steps + 1) * step_ratio)
-                .round()[::-1][:-1]
-                .copy()
-                .astype(np.int64)
-            )
+            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
             timesteps += self.config.steps_offset
         elif self.config.timestep_spacing == "trailing":
             step_ratio = self.config.num_train_timesteps / num_inference_steps
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (
-                np.arange(last_timestep, 0, -step_ratio)
-                .round()
-                .copy()
-                .astype(np.int64)
-            )
+            timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
             timesteps -= 1
         else:
             raise ValueError(
diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
index 132729b7eb88..6b8c2f1a8a28 100644
--- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -194,7 +194,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "linspace":
-            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
+                ::-1
+            ].copy()
         elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 654a8f5999ad..26057599be5e 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -207,7 +207,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "linspace":
-            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
+                ::-1
+            ].copy()
         elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index cfa3fe09d273..1256660b843c 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -221,7 +221,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "linspace":
-            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
+                ::-1
+            ].copy()
         elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index 7ce348c7e2d0..70ee1301129c 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -164,7 +164,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         self.num_inference_steps = num_inference_steps
 
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "linspace":
-            self._timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64)
+            self._timesteps = (
+                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64)
+            )
         elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
@@ -175,7 +177,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             step_ratio = self.config.num_train_timesteps / self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3
-            self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype(np.int64)
+            self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype(
+                np.int64
+            )
             self._timesteps -= 1
         else:
             raise ValueError(
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 451929edfc09..0fc5b5d7da8c 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -222,23 +222,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (
-                (np.arange(0, num_inference_steps + 1) * step_ratio)
-                .round()[::-1][:-1]
-                .copy()
-                .astype(np.int64)
-            )
+            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
             timesteps += self.config.steps_offset
         elif self.config.timestep_spacing == "trailing":
             step_ratio = self.config.num_train_timesteps / num_inference_steps
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (
-                np.arange(self.config.num_train_timesteps, 0, -step_ratio)
-                .round()
-                .copy()
-                .astype(np.int64)
-            )
+            timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
             timesteps -= 1
         else:
             raise ValueError(

From 27c134dc80ee26b920067c6e6857746f98851cbe Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 4 Jul 2023 12:35:04 +0200
Subject: [PATCH 15/18] Fix scheduler name in test

---
 tests/schedulers/test_schedulers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py
index b2df100b5d3e..cecc5cd00499 100755
--- a/tests/schedulers/test_schedulers.py
+++ b/tests/schedulers/test_schedulers.py
@@ -210,7 +210,7 @@ def test_default_arguments_not_in_config(self):
         )
         assert pipe.scheduler.__class__ == DDIMScheduler
 
-        # Default for PNDMScheduler
+        # Default for DDIMScheduler
         assert pipe.scheduler.config.timestep_spacing == "leading"
 
         # Switch to a different one, verify we use the default for that class

From c6fe8b0b66de3171612f56dc43c61a860e66c817 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 4 Jul 2023 13:25:00 +0200
Subject: [PATCH 16/18] Remove duplicated entries

---
 src/diffusers/schedulers/scheduling_unipc_multistep.py | 7 -------
 1 file changed, 7 deletions(-)
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 0fc5b5d7da8c..f9e1863c0ae0 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -124,13 +124,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         an offset added to the inference steps. You can use a combination of `offset=1` and
         `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
         stable diffusion.
-        timestep_spacing (`str`, default `"linspace"`):
-            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
-            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
-        steps_offset (`int`, default `0`):
-            an offset added to the inference steps. You can use a combination of `offset=1` and
-            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
-            stable diffusion.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]

From 981cf960db2d7a3bb816420fc0bd21f4700ebb92 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 4 Jul 2023 13:30:16 +0200
Subject: [PATCH 17/18] Add test for solver_type

This test currently fails in main. When switching from DEIS to UniPC,
solver_type is "logrho" (the default value from DEIS), which gets translated
to "bh1" by UniPC. This is different from the default value for UniPC:
"bh2".

This is where the translation happens:
https://github.com/huggingface/diffusers/blob/36d22d0709dc19776e3016fb3392d0f5578b0ab2/src/diffusers/schedulers/scheduling_unipc_multistep.py#L171

---
 tests/schedulers/test_schedulers.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py
index cecc5cd00499..2554c0d1ec15 100755
--- a/tests/schedulers/test_schedulers.py
+++ b/tests/schedulers/test_schedulers.py
@@ -25,11 +25,13 @@
 import diffusers
 from diffusers import (
     DDIMScheduler,
+    DEISMultistepScheduler,
     DiffusionPipeline,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     IPNDMScheduler,
     LMSDiscreteScheduler,
+    UniPCMultistepScheduler,
     VQDiffusionScheduler,
     logging,
 )
@@ -229,6 +231,19 @@ def test_default_arguments_not_in_config(self):
         pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
         assert pipe.scheduler.config.timestep_spacing == "trailing"
 
+    def test_default_solver_type_after_switch(self):
+        pipe = DiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16
+        )
+        assert pipe.scheduler.__class__ == DDIMScheduler
+
+        pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
+        assert pipe.scheduler.config.solver_type == "logrho"
+
+        # Switch to UniPC, verify the solver is the default
+        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+        assert pipe.scheduler.config.solver_type == "bh2"
+
 
 class SchedulerCommonTest(unittest.TestCase):
     scheduler_classes = ()

From 1774aad51e293b85de1efb57dc9847dfd190c783 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 4 Jul 2023 14:46:55 +0200
Subject: [PATCH 18/18] UniPC: use same default for solver_type

Fixes a bug when switching to UniPC from another scheduler (e.g., DEIS) that
uses a different solver type. The solver is now the same as if we had
instantiated the scheduler directly.
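
For illustration, the behavior this commit changes can be reproduced with the
tiny pipeline used in the new test (a sketch, not part of the diff):

from diffusers import DEISMultistepScheduler, DiffusionPipeline, UniPCMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
assert pipe.scheduler.config.solver_type == "logrho"  # DEIS default

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Before this commit: "bh1" (remapped from "logrho"); after: "bh2", the same
# default a directly instantiated UniPCMultistepScheduler would use.
print(pipe.scheduler.config.solver_type)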
---
 src/diffusers/schedulers/scheduling_unipc_multistep.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index f9e1863c0ae0..24aa36d660eb 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -177,7 +177,7 @@ def __init__(
 
         if solver_type not in ["bh1", "bh2"]:
             if solver_type in ["midpoint", "heun", "logrho"]:
-                self.register_to_config(solver_type="bh1")
+                self.register_to_config(solver_type="bh2")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
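
With the series applied, a specific spacing can also be requested explicitly when
swapping schedulers; overridden kwargs are written to the config and survive later
from_config round-trips (see PATCH 13). A usage sketch with the same test pipeline:

from diffusers import DDIMScheduler, DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
# Request the "trailing" spacing discussed in the referenced paper; the kwarg
# is recorded in the config and persists across scheduler switches.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
assert pipe.scheduler.config.timestep_spacing == "trailing"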