diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py
index 2ff9d002d691..e09db049ce99 100644
--- a/src/diffusers/schedulers/scheduling_sasolver.py
+++ b/src/diffusers/schedulers/scheduling_sasolver.py
@@ -165,15 +165,27 @@ def __init__(
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
-            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+            self.betas = torch.linspace(
+                beta_start, beta_end, num_train_timesteps, dtype=torch.float32
+            )
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            self.betas = (
+                torch.linspace(
+                    beta_start**0.5,
+                    beta_end**0.5,
+                    num_train_timesteps,
+                    dtype=torch.float32,
+                )
+                ** 2
+            )
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(
+                f"{beta_schedule} is not implemented for {self.__class__}"
+            )
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -187,11 +199,15 @@ def __init__(
         self.init_noise_sigma = 1.0
 
         if algorithm_type not in ["data_prediction", "noise_prediction"]:
-            raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
+            raise NotImplementedError(
+                f"{algorithm_type} is not implemented for {self.__class__}"
+            )
 
         # setable values
         self.num_inference_steps = None
-        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
+        timesteps = np.linspace(
+            0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32
+        )[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps)
         self.timestep_list = [None] * max(predictor_order, corrector_order - 1)
         self.model_outputs = [None] * max(predictor_order, corrector_order - 1)
@@ -213,7 +229,9 @@ def step_index(self):
         """
         return self._step_index
 
-    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self, num_inference_steps: int = None, device: Union[str, torch.device] = None
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -225,26 +243,38 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
         """
         # Clipping the minimum of all lambda(t) for numerical stability.
         # This is critical for cosine (squaredcos_cap_v2) noise schedule.
-        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
+        clipped_idx = torch.searchsorted(
+            torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped
+        )
         last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
 
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "linspace":
             timesteps = (
-                np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
+                np.linspace(0, last_timestep - 1, num_inference_steps + 1)
+                .round()[::-1][:-1]
+                .copy()
+                .astype(np.int64)
             )
         elif self.config.timestep_spacing == "leading":
             step_ratio = last_timestep // (num_inference_steps + 1)
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
+            timesteps = (
+                (np.arange(0, num_inference_steps + 1) * step_ratio)
+                .round()[::-1][:-1]
+                .copy()
+                .astype(np.int64)
+            )
             timesteps += self.config.steps_offset
         elif self.config.timestep_spacing == "trailing":
             step_ratio = self.config.num_train_timesteps / num_inference_steps
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
+            timesteps = (
+                np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
+            )
             timesteps -= 1
         else:
             raise ValueError(
@@ -255,8 +285,12 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
         if self.config.use_karras_sigmas:
             log_sigmas = np.log(sigmas)
             sigmas = np.flip(sigmas).copy()
-            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            sigmas = self._convert_to_karras(
+                in_sigmas=sigmas, num_inference_steps=num_inference_steps
+            )
+            timesteps = np.array(
+                [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
+            ).round()
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -264,7 +298,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas)
-        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
+        self.timesteps = torch.from_numpy(timesteps).to(
+            device=device, dtype=torch.int64
+        )
         self.num_inference_steps = len(timesteps)
 
         self.model_outputs = [
@@ -292,7 +328,9 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         batch_size, channels, *remaining_dims = sample.shape
 
         if dtype not in (torch.float32, torch.float64):
-            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+            sample = (
+                sample.float()
+            )  # upcast for quantile calculation, and clamp not implemented for cpu half
 
         # Flatten sample for doing quantile calculation along each image
         sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
@@ -304,7 +342,9 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
             s, min=1, max=self.config.sample_max_value
         )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
         s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
-        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+        sample = (
+            torch.clamp(sample, -s, s) / s
+        )  # "we threshold xt0 to the range [-s, s] and then divide by s"
 
         sample = sample.reshape(batch_size, channels, *remaining_dims)
         sample = sample.to(dtype)
@@ -320,7 +360,11 @@ def _sigma_to_t(self, sigma, log_sigmas):
         dists = log_sigma - log_sigmas[:, np.newaxis]
 
         # get sigmas range
-        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        low_idx = (
+            np.cumsum((dists >= 0), axis=0)
+            .argmax(axis=0)
+            .clip(max=log_sigmas.shape[0] - 2)
+        )
         high_idx = low_idx + 1
 
         low = log_sigmas[low_idx]
@@ -343,7 +387,9 @@ def _sigma_to_alpha_sigma_t(self, sigma):
         return alpha_t, sigma_t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(
+        self, in_sigmas: torch.FloatTensor, num_inference_steps
+    ) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         # Hack to make sure that other schedulers which copy this function don't break
@@ -460,21 +506,27 @@ def convert_model_output(
 
         return epsilon
 
-    def get_coefficients_exponential_negative(self, order, interval_start, interval_end):
+    def get_coefficients_exponential_negative(
+        self, order, interval_start, interval_end
+    ):
         """
         Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end
         """
         assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"
 
         if order == 0:
-            return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1)
+            return torch.exp(-interval_end) * (
+                torch.exp(interval_end - interval_start) - 1
+            )
         elif order == 1:
             return torch.exp(-interval_end) * (
-                (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1)
+                (interval_start + 1) * torch.exp(interval_end - interval_start)
+                - (interval_end + 1)
             )
         elif order == 2:
             return torch.exp(-interval_end) * (
-                (interval_start**2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start)
+                (interval_start**2 + 2 * interval_start + 2)
+                * torch.exp(interval_end - interval_start)
                 - (interval_end**2 + 2 * interval_end + 2)
             )
         elif order == 3:
@@ -484,7 +536,9 @@ def get_coefficients_exponential_negative(self, order, interval_start, interval_
                 - (interval_end**3 + 3 * interval_end**2 + 6 * interval_end + 6)
             )
 
-    def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau):
+    def get_coefficients_exponential_positive(
+        self, order, interval_start, interval_end, tau
+    ):
         """
         Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end
         """
@@ -496,14 +550,17 @@ def get_coefficients_exponential_positive(self, order, interval_start, interval_
 
         if order == 0:
             return (
-                torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / (1 + tau**2)
+                torch.exp(interval_end_cov)
+                * (1 - torch.exp(-(interval_end_cov - interval_start_cov)))
+                / (1 + tau**2)
             )
         elif order == 1:
             return (
                 torch.exp(interval_end_cov)
                 * (
                     (interval_end_cov - 1)
-                    - (interval_start_cov - 1) * torch.exp(-(interval_end_cov - interval_start_cov))
+                    - (interval_start_cov - 1)
+                    * torch.exp(-(interval_end_cov - interval_start_cov))
                 )
                 / ((1 + tau**2) ** 2)
             )
@@ -521,8 +578,18 @@ def get_coefficients_exponential_positive(self, order, interval_start, interval_
             return (
                 torch.exp(interval_end_cov)
                 * (
-                    (interval_end_cov**3 - 3 * interval_end_cov**2 + 6 * interval_end_cov - 6)
-                    - (interval_start_cov**3 - 3 * interval_start_cov**2 + 6 * interval_start_cov - 6)
+                    (
+                        interval_end_cov**3
+                        - 3 * interval_end_cov**2
+                        + 6 * interval_end_cov
+                        - 6
+                    )
+                    - (
+                        interval_start_cov**3
+                        - 3 * interval_start_cov**2
+                        + 6 * interval_start_cov
+                        - 6
+                    )
                     * torch.exp(-(interval_end_cov - interval_start_cov))
                 )
                 / ((1 + tau**2) ** 4)
@@ -539,13 +606,25 @@ def lagrange_polynomial_coefficient(self, order, lambda_list):
             return [[1]]
         elif order == 1:
             return [
-                [1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
-                [1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])],
+                [
+                    1 / (lambda_list[0] - lambda_list[1]),
+                    -lambda_list[1] / (lambda_list[0] - lambda_list[1]),
+                ],
+                [
+                    1 / (lambda_list[1] - lambda_list[0]),
+                    -lambda_list[0] / (lambda_list[1] - lambda_list[0]),
+                ],
             ]
         elif order == 2:
-            denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
-            denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2])
-            denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1])
+            denominator1 = (lambda_list[0] - lambda_list[1]) * (
+                lambda_list[0] - lambda_list[2]
+            )
+            denominator2 = (lambda_list[1] - lambda_list[0]) * (
+                lambda_list[1] - lambda_list[2]
+            )
+            denominator3 = (lambda_list[2] - lambda_list[0]) * (
+                lambda_list[2] - lambda_list[1]
+            )
             return [
                 [
                     1 / denominator1,
@@ -631,24 +710,36 @@ def lagrange_polynomial_coefficient(self, order, lambda_list):
             ],
         ]
 
-    def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau):
+    def get_coefficients_fn(
+        self, order, interval_start, interval_end, lambda_list, tau
+    ):
         assert order in [1, 2, 3, 4]
-        assert order == len(lambda_list), "the length of lambda list must be equal to the order"
+        assert order == len(
+            lambda_list
+        ), "the length of lambda list must be equal to the order"
         coefficients = []
-        lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list)
+        lagrange_coefficient = self.lagrange_polynomial_coefficient(
+            order - 1, lambda_list
+        )
         for i in range(order):
             coefficient = 0
             for j in range(order):
                 if self.predict_x0:
-                    coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_positive(
+                    coefficient += lagrange_coefficient[i][
+                        j
+                    ] * self.get_coefficients_exponential_positive(
                         order - 1 - j, interval_start, interval_end, tau
                     )
                 else:
-                    coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_negative(
+                    coefficient += lagrange_coefficient[i][
+                        j
+                    ] * self.get_coefficients_exponential_negative(
                         order - 1 - j, interval_start, interval_end
                     )
             coefficients.append(coefficient)
-        assert len(coefficients) == order, "the length of coefficients does not match the order"
+        assert (
+            len(coefficients) == order
+        ), "the length of coefficients does not match the order"
         return coefficients
 
     def stochastic_adams_bashforth_update(
@@ -706,7 +797,10 @@ def stochastic_adams_bashforth_update(
             "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
         )
         model_output_list = self.model_outputs
-        sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+        sigma_t, sigma_s0 = (
+            self.sigmas[self.step_index + 1],
+            self.sigmas[self.step_index],
+        )
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
         alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
         lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -722,8 +816,9 @@
             lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
             lambda_list.append(lambda_si)
 
-
-        gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau)
+        gradient_coefficients = self.get_coefficients_fn(
+            order, lambda_s0, lambda_t, lambda_list, tau
+        )
 
         x = sample
 
@@ -741,13 +836,21 @@
                 gradient_coefficients[0] += (
                     1.0
                     * torch.exp((1 + tau**2) * lambda_t)
-                    * (h**2 / 2 - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) / ((1 + tau**2) ** 2))
+                    * (
+                        h**2 / 2
+                        - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h)))
+                        / ((1 + tau**2) ** 2)
+                    )
                     / (lambda_s0 - temp_lambda_s)
                 )
                 gradient_coefficients[1] -= (
                     1.0
                     * torch.exp((1 + tau**2) * lambda_t)
-                    * (h**2 / 2 - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) / ((1 + tau**2) ** 2))
+                    * (
+                        h**2 / 2
+                        - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h)))
+                        / ((1 + tau**2) ** 2)
+                    )
                     / (lambda_s0 - temp_lambda_s)
                 )
 
@@ -761,7 +864,12 @@
                     * model_output_list[-(i + 1)]
                 )
             else:
-                gradient_part += -(1 + tau**2) * alpha_t * gradient_coefficients[i] * model_output_list[-(i + 1)]
+                gradient_part += (
+                    -(1 + tau**2)
+                    * alpha_t
+                    * gradient_coefficients[i]
+                    * model_output_list[-(i + 1)]
+                )
 
         if self.predict_x0:
             noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau**2 * h)) * noise
@@ -769,7 +877,11 @@
             noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise
 
         if self.predict_x0:
-            x_t = torch.exp(-(tau**2) * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part
+            x_t = (
+                torch.exp(-(tau**2) * h) * (sigma_t / sigma_s0) * x
+                + gradient_part
+                + noise_part
+            )
         else:
             x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part
 
@@ -841,7 +953,10 @@ def stochastic_adams_moulton_update(
         )
         model_output_list = self.model_outputs
 
-        sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
+        sigma_t, sigma_s0 = (
+            self.sigmas[self.step_index],
+            self.sigmas[self.step_index - 1],
+        )
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
         alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
 
@@ -856,10 +971,11 @@
             lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
             lambda_list.append(lambda_si)
 
-
         model_prev_list = model_output_list + [this_model_output]
 
-        gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau)
+        gradient_coefficients = self.get_coefficients_fn(
+            order, lambda_s0, lambda_t, lambda_list, tau
+        )
 
         x = last_sample
 
@@ -874,12 +990,20 @@
             gradient_coefficients[0] += (
                 1.0
                 * torch.exp((1 + tau**2) * lambda_t)
-                * (h / 2 - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) / ((1 + tau**2) ** 2 * h))
+                * (
+                    h / 2
+                    - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h)))
+                    / ((1 + tau**2) ** 2 * h)
+                )
             )
             gradient_coefficients[1] -= (
                 1.0
                 * torch.exp((1 + tau**2) * lambda_t)
-                * (h / 2 - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) / ((1 + tau**2) ** 2 * h))
+                * (
+                    h / 2
+                    - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h)))
+                    / ((1 + tau**2) ** 2 * h)
+                )
             )
 
         for i in range(order):
@@ -892,15 +1016,26 @@
                     * model_prev_list[-(i + 1)]
                 )
             else:
-                gradient_part += -(1 + tau**2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]
+                gradient_part += (
+                    -(1 + tau**2)
+                    * alpha_t
+                    * gradient_coefficients[i]
+                    * model_prev_list[-(i + 1)]
+                )
 
         if self.predict_x0:
-            noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau**2 * h)) * last_noise
+            noise_part = (
+                sigma_t * torch.sqrt(1 - torch.exp(-2 * tau**2 * h)) * last_noise
+            )
         else:
             noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * last_noise
 
         if self.predict_x0:
-            x_t = torch.exp(-(tau**2) * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part
+            x_t = (
+                torch.exp(-(tau**2) * h) * (sigma_t / sigma_s0) * x
+                + gradient_part
+                + noise_part
+            )
         else:
             x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part
 
@@ -979,7 +1114,9 @@ def step(
             tau=current_tau,
         )
 
-        for i in range(max(self.config.predictor_order, self.config.corrector_order - 1) - 1):
+        for i in range(
+            max(self.config.predictor_order, self.config.corrector_order - 1) - 1
+        ):
             self.model_outputs[i] = self.model_outputs[i + 1]
             self.timestep_list[i] = self.timestep_list[i + 1]
 
@@ -987,18 +1124,29 @@ def step(
         self.timestep_list[-1] = timestep
 
         noise = randn_tensor(
-            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+            model_output.shape,
+            generator=generator,
+            device=model_output.device,
+            dtype=model_output.dtype,
        )
 
         if self.config.lower_order_final:
-            this_predictor_order = min(self.config.predictor_order, len(self.timesteps) - self.step_index)
-            this_corrector_order = min(self.config.corrector_order, len(self.timesteps) - self.step_index + 1)
+            this_predictor_order = min(
+                self.config.predictor_order, len(self.timesteps) - self.step_index
+            )
+            this_corrector_order = min(
+                self.config.corrector_order, len(self.timesteps) - self.step_index + 1
+            )
         else:
             this_predictor_order = self.config.predictor_order
             this_corrector_order = self.config.corrector_order
 
-        self.this_predictor_order = min(this_predictor_order, self.lower_order_nums + 1)  # warmup for multistep
-        self.this_corrector_order = min(this_corrector_order, self.lower_order_nums + 2)  # warmup for multistep
+        self.this_predictor_order = min(
+            this_predictor_order, self.lower_order_nums + 1
+        )  # warmup for multistep
+        self.this_corrector_order = min(
+            this_corrector_order, self.lower_order_nums + 2
+        )  # warmup for multistep
         assert self.this_predictor_order > 0
         assert self.this_corrector_order > 0
 
@@ -1014,7 +1162,9 @@ def step(
             tau=current_tau,
         )
 
-        if self.lower_order_nums < max(self.config.predictor_order, self.config.corrector_order - 1):
+        if self.lower_order_nums < max(
+            self.config.predictor_order, self.config.corrector_order - 1
+        ):
             self.lower_order_nums += 1
 
         # upon completion increase step index by one
@@ -1025,7 +1175,9 @@ def step(
 
         return SchedulerOutput(prev_sample=prev_sample)
 
-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def scale_model_input(
+        self, sample: torch.FloatTensor, *args, **kwargs
+    ) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
@@ -1048,7 +1200,9 @@ def add_noise(
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        alphas_cumprod = self.alphas_cumprod.to(
+            device=original_samples.device, dtype=original_samples.dtype
+        )
         timesteps = timesteps.to(original_samples.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -1061,7 +1215,9 @@ def add_noise(
         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
 
-        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+        noisy_samples = (
+            sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+        )
         return noisy_samples
 
     def __len__(self):
diff --git a/tests/schedulers/test_scheduler_sasolver.py b/tests/schedulers/test_scheduler_sasolver.py
index 1b8a3dc69ac2..32ab889d203f 100644
--- a/tests/schedulers/test_scheduler_sasolver.py
+++ b/tests/schedulers/test_scheduler_sasolver.py
@@ -37,19 +37,29 @@ def test_step_shape(self):
 
             if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                 scheduler.set_timesteps(num_inference_steps)
-            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+            elif num_inference_steps is not None and not hasattr(
+                scheduler, "set_timesteps"
+            ):
                 kwargs["num_inference_steps"] = num_inference_steps
 
             # copy over dummy past residuals (must be done after set_timesteps)
             dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
-            scheduler.model_outputs = \
-                dummy_past_residuals[: max(scheduler.config.predictor_order, scheduler.config.corrector_order - 1)]
+            scheduler.model_outputs = dummy_past_residuals[
+                : max(
+                    scheduler.config.predictor_order,
+                    scheduler.config.corrector_order - 1,
+                )
+            ]
 
             time_step_0 = scheduler.timesteps[5]
             time_step_1 = scheduler.timesteps[6]
 
-            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
-            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
+            output_0 = scheduler.step(
+                residual, time_step_0, sample, **kwargs
+            ).prev_sample
+            output_1 = scheduler.step(
+                residual, time_step_1, sample, **kwargs
+            ).prev_sample
 
             self.assertEqual(output_0.shape, sample.shape)
             self.assertEqual(output_0.shape, output_1.shape)
@@ -59,7 +69,9 @@ def test_timesteps(self):
             self.check_over_configs(num_train_timesteps=timesteps)
 
     def test_betas(self):
-        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
+        for beta_start, beta_end in zip(
+            [0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]
+        ):
             self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
 
     def test_schedules(self):
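
Reviewer note (not part of the diff): every hunk above only rewraps long lines, so scheduler behavior should be unchanged. A quick end-to-end smoke test, as a minimal sketch — it assumes a diffusers build that exports SASolverScheduler, and the latent shape plus the random stand-in for a denoising model are illustrative placeholders:

import torch
from diffusers import SASolverScheduler

scheduler = SASolverScheduler()  # defaults: predictor_order=2, corrector_order=2
scheduler.set_timesteps(num_inference_steps=20)

generator = torch.Generator().manual_seed(0)
sample = torch.randn(1, 3, 8, 8, generator=generator) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    # A real pipeline would call a denoising model here; random tensors are
    # enough to exercise the predictor/corrector code paths touched by this PR.
    model_output = torch.randn(sample.shape, generator=generator)
    sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample
assert sample.shape == (1, 3, 8, 8)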
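
The closed-form antiderivatives in get_coefficients_exponential_negative (the integrals of x^order * exp(-x) over [interval_start, interval_end]) are rewrapped but not altered, so they can be cross-checked against trapezoidal quadrature. A sketch — the interval endpoints and the tolerance are arbitrary choices, not values from this PR:

import torch
from diffusers import SASolverScheduler

scheduler = SASolverScheduler()
a = torch.tensor(0.3, dtype=torch.float64)
b = torch.tensor(1.7, dtype=torch.float64)
x = torch.linspace(0.3, 1.7, 200_001, dtype=torch.float64)

for order in range(4):
    closed_form = scheduler.get_coefficients_exponential_negative(order, a, b)
    # Numerical value of the integral of x^order * exp(-x) dx from a to b
    numeric = torch.trapezoid(x**order * torch.exp(-x), x)
    assert torch.allclose(closed_form, numeric, atol=1e-9), (order, closed_form, numeric)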
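
Likewise, the rewrapped add_noise still computes the standard forward-noising identity x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, which is easy to spot-check with a zero sample (again a sketch, assuming the default 1000-step training schedule):

import torch
from diffusers import SASolverScheduler

scheduler = SASolverScheduler()
x0 = torch.zeros(2, 3, 4, 4)
noise = torch.ones_like(x0)
t = torch.tensor([0, 999])

x_t = scheduler.add_noise(x0, noise, t)
# With x_0 = 0, only the sqrt(1 - alpha_bar_t) * noise term survives.
expected = (1 - scheduler.alphas_cumprod[t]) ** 0.5
assert torch.allclose(x_t, expected.view(-1, 1, 1, 1) * noise)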