7 changes: 7 additions & 0 deletions src/diffusers/configuration_utils.py
@@ -423,6 +423,10 @@ def _get_init_keys(cls):

@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
# Skip keys that were not present in the original config, so default __init__ values were used
used_defaults = config_dict.get("_use_default_values", [])
config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

# 0. Copy origin config dict
original_dict = dict(config_dict.items())

@@ -503,6 +507,9 @@ def extract_init_dict(cls, config_dict, **kwargs):
# 7. Define "hidden" config parameters that were saved for compatible classes
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}

# 8. Take note of the parameters that were not present in the loaded config
hidden_config_dict["_use_default_values"] = expected_keys - set(init_dict)

return init_dict, unused_kwargs, hidden_config_dict

@classmethod
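The two hunks above work as a pair: at save time, step 8 records which `__init__` parameters never appeared in the loaded config, and at load time the new filter drops exactly those keys so the class's current defaults win. A minimal standalone sketch of the load-side filter (toy dict only; the key name `_use_default_values` matches the diff, everything else is illustrative):

```python
# Toy illustration of the load-side filtering added in extract_init_dict.
# A config saved before `timestep_spacing` existed records that key in
# `_use_default_values`; on reload the key is dropped, so the (possibly
# newer) __init__ default is used instead of a stale serialized value.
loaded_config = {
    "num_train_timesteps": 1000,
    "timestep_spacing": "linspace",  # e.g. injected by an old compatible-class path
    "_use_default_values": ["timestep_spacing"],
}

used_defaults = loaded_config.get("_use_default_values", [])
init_kwargs = {
    k: v
    for k, v in loaded_config.items()
    if k not in used_defaults and k != "_use_default_values"
}
print(init_kwargs)  # {'num_train_timesteps': 1000} -> the __init__ default wins
```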
11 changes: 9 additions & 2 deletions src/diffusers/schedulers/scheduling_ddim.py
@@ -302,8 +302,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

self.num_inference_steps = num_inference_steps

# "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "leading":
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
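For intuition, here is a standalone numpy sketch of what the three strategies produce for a hypothetical schedule with 1000 training steps sampled at 10 inference steps, mirroring the branches above with `steps_offset = 0` (the `"trailing"` branch is truncated by the fold here but matches the DDPM version in the next file):

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10

# "linspace": evenly spaced over [0, T-1]; includes both endpoints.
linspace = (
    np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].astype(np.int64)
)

# "leading": integer stride starting at 0; never reaches the last training step.
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)

# "trailing": counts down from T and shifts by one; always starts at T-1.
step_ratio = num_train_timesteps / num_inference_steps
trailing = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1

print(linspace)  # [999 888 777 666 555 444 333 222 111   0]
print(leading)   # [900 800 700 600 500 400 300 200 100   0]
print(trailing)  # [999 899 799 699 599 499 399 299 199  99]
```

The paper's core complaint is visible in the `leading` row: sampling never starts from the terminal timestep 999, so the first denoising step assumes more signal than pure noise actually contains.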
37 changes: 34 additions & 3 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -114,6 +114,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, default `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -134,6 +141,8 @@ def __init__(
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
[Inline review comment from the Member Author: default is "leading" here.]

steps_offset: int = 1,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -228,11 +237,33 @@ def set_timesteps(
)

self.num_inference_steps = num_inference_steps

step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.custom_timesteps = False

# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)

self.timesteps = torch.from_numpy(timesteps).to(device)

def _get_variance(self, t, predicted_variance=None, variance_type=None):
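Since the new arguments are config-driven, the fix is opt-in per load; existing checkpoints keep `"leading"`. A hedged usage sketch (the checkpoint name is only an example; passing config overrides through `from_pretrained` kwargs is standard `ConfigMixin` behavior):

```python
from diffusers import DDPMScheduler

# Opt in to the paper's recommendation: trailing spacing makes sampling
# start at the terminal training timestep (t = 999 for T = 1000).
scheduler = DDPMScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # example repo; any with a scheduler config works
    subfolder="scheduler",
    timestep_spacing="trailing",
)
scheduler.set_timesteps(50)
print(scheduler.timesteps[0])  # tensor(999)
```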
40 changes: 33 additions & 7 deletions src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -103,7 +103,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10.

timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -125,6 +131,8 @@ def __init__(
algorithm_type: str = "deis",
solver_type: str = "logrho",
lower_order_final: bool = True,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -181,12 +189,30 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)

# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
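One difference from the DDIM/DDPM versions: DEIS is a multistep solver, so `"linspace"` and `"leading"` build a grid of `num_inference_steps + 1` points and drop the terminal one with `[:-1]`; the duplicate note at the bottom of the hunk exists because that rounding can collide once `num_inference_steps` approaches `num_train_timesteps`. A standalone check of the `"linspace"` branch:

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10

# num_inference_steps + 1 grid points, as in the "linspace" branch above.
grid = (
    np.linspace(0, num_train_timesteps - 1, num_inference_steps + 1)
    .round()[::-1]
    .astype(np.int64)
)
timesteps = grid[:-1]  # drop the terminal point; the solver integrates toward it

print(grid)       # [999 899 799 699 599 500 400 300 200 100   0]
print(timesteps)  # 10 entries: [999 899 799 699 599 500 400 300 200 100]
```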
38 changes: 32 additions & 6 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -134,6 +134,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
diffusion ODEs.
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -158,6 +165,8 @@ def __init__(
use_karras_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -218,12 +227,29 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = last_timestep // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)

if self.use_karras_sigmas:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
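The extra moving part here is `lambda_min_clipped`: the grid's upper end becomes `last_timestep = num_train_timesteps - clipped_idx` rather than `num_train_timesteps`, which keeps every λ(t) on the grid above the clip (mainly relevant for the cosine `squaredcos_cap_v2` schedule). A standalone sketch of that index computation with a toy, monotonically decreasing λ table:

```python
import torch

num_train_timesteps = 1000
# Toy decreasing lambda(t); the real scheduler derives it from the noise
# schedule as lambda_t = log(alpha_t) - log(sigma_t).
lambda_t = torch.linspace(5.0, -12.0, num_train_timesteps)

lambda_min_clipped = -10.0

# lambda_t decreases with t, so flip it before searchsorted (ascending input).
clipped_idx = torch.searchsorted(torch.flip(lambda_t, [0]), lambda_min_clipped)
last_timestep = (num_train_timesteps - clipped_idx).item()

print(clipped_idx.item(), last_timestep)  # 118 882 -> timesteps laid out on [0, 881]
```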
40 changes: 36 additions & 4 deletions src/diffusers/schedulers/scheduling_dpmsolver_sde.py
@@ -133,6 +133,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
noise_sampler_seed (`int`, *optional*, defaults to `None`):
The random seed to use for the noise sampler. If `None`, a random seed will be generated.
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -149,6 +156,8 @@ def __init__(
prediction_type: str = "epsilon",
use_karras_sigmas: Optional[bool] = False,
noise_sampler_seed: Optional[int] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -187,6 +196,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
pos = 0
return indices[pos].item()

@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()

return (self.sigmas.max() ** 2 + 1) ** 0.5

def scale_model_input(
self,
sample: torch.FloatTensor,
@@ -226,7 +243,25 @@ def set_timesteps(

num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
@@ -242,9 +277,6 @@ def set_timesteps(
sigmas = torch.from_numpy(sigmas).to(device=device)
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()

timesteps = torch.from_numpy(timesteps)
second_order_timesteps = torch.from_numpy(second_order_timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
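Note the last hunk also deletes the old unconditional `self.init_noise_sigma = self.sigmas.max()` assignment from `set_timesteps`; the new property reproduces it for `"linspace"`/`"trailing"` and switches `"leading"` to a `sqrt(sigma_max**2 + 1)` scaling, matching the convention other discrete schedulers in the library use for that spacing. A toy sketch of the distinction (the sigma table is made up):

```python
import torch

def init_noise_sigma(timestep_spacing: str, sigmas: torch.Tensor) -> torch.Tensor:
    # Mirrors the property above: "linspace"/"trailing" start sampling at
    # sigma_max directly, while "leading" scales the initial latent by
    # sqrt(sigma_max**2 + 1).
    if timestep_spacing in ["linspace", "trailing"]:
        return sigmas.max()
    return (sigmas.max() ** 2 + 1) ** 0.5

sigmas = torch.tensor([0.03, 0.3, 1.2, 4.7, 14.6])  # made-up sigma table
print(init_noise_sigma("trailing", sigmas))  # tensor(14.6000)
print(init_noise_sigma("leading", sigmas))   # tensor(14.6342)
```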