@@ -786,13 +786,40 @@ def prepare_control_image(
return image

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
if denoising_start is None:
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
else:
t_start = 0

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

# Strength is irrelevant if we directly request a timestep to start at;
# the effective starting point is determined by `denoising_start` instead.
if denoising_start is not None:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
)

num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
# if the scheduler is a 2nd order scheduler we might have to do +1
# because `num_inference_steps` might be even given that every timestep
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
# mean that we cut the timesteps in the middle of the denoising step
# (between the 1st and 2nd derivative) which leads to incorrect results. By adding 1
# we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
num_inference_steps = num_inference_steps + 1

# because t_n >= t_n+1 (the timesteps are in descending order), we slice the timesteps starting from the end
timesteps = timesteps[-num_inference_steps:]
return timesteps, num_inference_steps

return timesteps, num_inference_steps - t_start

# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
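The new `denoising_start` branch above converts a fraction of the schedule into a discrete timestep cutoff and keeps only the steps below it. A minimal standalone sketch of that arithmetic (not the pipeline code itself), assuming `num_train_timesteps=1000` and linspace-spaced inference timesteps:

```python
import numpy as np

num_train_timesteps = 1000      # assumed value of scheduler.config.num_train_timesteps
num_inference_steps = 40
denoising_start = 0.75          # resume only the last 25% of the schedule

# descending timesteps, as schedulers emit them (highest noise level first)
timesteps = np.linspace(num_train_timesteps - 1, 0, num_inference_steps).round()

cutoff = int(round(num_train_timesteps - denoising_start * num_train_timesteps))  # 250
remaining = int((timesteps < cutoff).sum())  # steps strictly below the cutoff
timesteps = timesteps[-remaining:]           # slice from the end: the low-noise tail

print(remaining, timesteps[0])               # 10 steps, starting just below t=250
```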
@@ -983,6 +1010,14 @@ def do_classifier_free_guidance(self):
def cross_attention_kwargs(self):
return self._cross_attention_kwargs

@property
def denoising_end(self):
return self._denoising_end

@property
def denoising_start(self):
return self._denoising_start

@property
def num_timesteps(self):
return self._num_timesteps
@@ -999,6 +1034,8 @@ def __call__(
width: Optional[int] = None,
strength: float = 0.8,
num_inference_steps: int = 50,
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
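For context, a hedged usage sketch of the two new call arguments: the first call stops denoising early and returns latents, the second resumes from the same point. The checkpoint names, the preloaded `init_image`/`control_image`, and the reuse of a single pipeline for both passes are illustrative assumptions, and `denoising_end`/`denoising_start` only exist on this pipeline with this patch applied:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# First pass: denoise only the first 80% of the schedule and hand back latents.
latents = pipe(
    prompt="a photo of a castle",
    image=init_image,               # assumed: preloaded PIL image
    control_image=control_image,    # assumed: preloaded conditioning image
    num_inference_steps=50,
    denoising_end=0.8,
    output_type="latent",
).images

# Second pass: resume at the 80% mark from those latents (no fresh noise is added).
image = pipe(
    prompt="a photo of a castle",
    image=latents,
    control_image=control_image,
    num_inference_steps=50,
    denoising_start=0.8,
).images[0]
```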
@@ -1240,6 +1277,8 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._denoising_start = denoising_start

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1326,11 +1365,20 @@ def __call__(
assert False

# 5. Prepare timesteps
def denoising_value_valid(dnv):
return isinstance(dnv, float) and 0 < dnv < 1

self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
strength,
device,
denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
self._num_timesteps = len(timesteps)

add_noise = True if self.denoising_start is None else False
# 6. Prepare latent variables
latents = self.prepare_latents(
image,
@@ -1340,7 +1388,7 @@ def __call__(
prompt_embeds.dtype,
device,
generator,
True,
add_noise,
)

# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -1399,6 +1447,28 @@ def __call__(

# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

if (
self.denoising_end is not None
and self.denoising_start is not None
and denoising_value_valid(self.denoising_end)
and denoising_value_valid(self.denoising_start)
and self.denoising_start >= self.denoising_end
):
raise ValueError(
f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ f" {self.denoising_end} when using type float."
)
elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
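A tiny worked example of the `denoising_end` cutoff computed just above, under the same assumptions as before (`num_train_timesteps=1000`, linspace spacing). Note that this slice keeps the front of the descending timestep array, whereas `denoising_start` keeps the tail:

```python
import numpy as np

timesteps = np.linspace(999, 0, 50).round()   # descending inference timesteps
cutoff = int(round(1000 - 0.8 * 1000))        # denoising_end=0.8 -> cutoff at t=200
kept = int((timesteps >= cutoff).sum())       # the first ~80% of the steps survive
timesteps = timesteps[:kept]                  # stop denoising early

print(kept, timesteps[-1])                    # 40 steps, ending just above t=200
```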
8 changes: 5 additions & 3 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -266,16 +266,18 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
sigma_last = [0.0] if self.euler_at_final else sigmas[-1:]
sigmas = np.concatenate([sigmas, sigma_last]).astype(np.float32)
elif self.config.use_lu_lambdas:
lambdas = np.flip(log_sigmas.copy())
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
sigmas = np.exp(lambdas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
sigma_last = [0.0] if self.euler_at_final else sigmas[-1:]
sigmas = np.concatenate([sigmas, sigma_last]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigma_last = 0.0 if self.euler_at_final else ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

self.sigmas = torch.from_numpy(sigmas)
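A hedged sketch of opting into the zero final sigma introduced above. That `euler_at_final` is accepted as a constructor/config flag is an assumption inferred from the diff, not shown in it:

```python
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(euler_at_final=True)  # assumed constructor flag
scheduler.set_timesteps(num_inference_steps=30)
print(scheduler.sigmas[-1])  # 0.0 with the flag; the smallest trained sigma otherwise
```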
5 changes: 5 additions & 0 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -139,6 +139,7 @@ def __init__(
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
use_nihanth_order_dropoff: bool = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -166,6 +167,7 @@ def __init__(
self.set_timesteps(num_train_timesteps, None)
self.derivatives = []
self.is_scale_input_called = False
self.use_nihanth_order_dropoff = use_nihanth_order_dropoff

self._step_index = None

@@ -400,6 +402,9 @@ def step(

# 3. Compute linear multistep coefficients
order = min(self.step_index + 1, order)
if self.use_nihanth_order_dropoff:
# drop to order=1 when fewer than 15 steps remain
order = 1 if len(self.timesteps) - self.step_index < 15 else order
lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)]

# 4. Compute previous sample based on the derivatives path
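Finally, a hedged sketch of enabling the order drop-off added above. `use_nihanth_order_dropoff` exists only with this patch applied and is not an upstream diffusers option; `pipe` is assumed to be an already-loaded pipeline:

```python
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    use_nihanth_order_dropoff=True,  # fall back to order=1 for the final steps
)
pipe.scheduler = scheduler
```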