diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 62a13c33cd89..dcf899935deb 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -280,10 +280,12 @@ def step(
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
         elif self.config.prediction_type == "sample":
             pred_original_sample = model_output
+        elif self.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
         else:
             raise ValueError(
-                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
-                " for the DDPMScheduler."
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+                " `v_prediction` for the DDPMScheduler."
             )

         # 3. Clip "predicted x_0"
diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
index fe8a36c43f51..7011e1f30a4e 100644
--- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -78,6 +78,7 @@ def __init__(
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -202,7 +203,16 @@ def step(
         sigma = self.sigmas[step_index]

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        pred_original_sample = sample - sigma * model_output
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = sample - sigma * model_output
+        elif self.config.prediction_type == "v_prediction":
+            # * c_out + input * c_skip
+            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )
+
         sigma_from = self.sigmas[step_index]
         sigma_to = self.sigmas[step_index + 1]
         sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
diff --git a/src/diffusers/schedulers/scheduling_heun.py b/src/diffusers/schedulers/scheduling_heun.py
index d21591b3df21..27a54f645e4e 100644
--- a/src/diffusers/schedulers/scheduling_heun.py
+++ b/src/diffusers/schedulers/scheduling_heun.py
@@ -54,6 +54,7 @@ def __init__(
         beta_end: float = 0.012,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -184,7 +185,15 @@ def step(
         sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        pred_original_sample = sample - sigma_hat * model_output
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = sample - sigma_hat * model_output
+        elif self.config.prediction_type == "v_prediction":
+            # * c_out + input * c_skip
+            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )

         if self.state_in_first_order:
             # 2. Convert to an ODE derivative
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 4c28db591a62..044eb70c213c 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -78,6 +78,7 @@ def __init__(
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -215,7 +216,15 @@ def step(
         sigma = self.sigmas[step_index]

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        pred_original_sample = sample - sigma * model_output
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = sample - sigma * model_output
+        elif self.config.prediction_type == "v_prediction":
+            # * c_out + input * c_skip
+            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )

         # 2. Convert to an ODE derivative
         derivative = (sample - pred_original_sample) / sigma
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index bb3e098f7e42..ddfd1338a4e5 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -102,6 +102,7 @@ def __init__(
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
+        prediction_type: str = "epsilon",
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
@@ -368,6 +369,13 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev

+        if self.config.prediction_type == "v_prediction":
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        elif self.config.prediction_type != "epsilon":
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
+            )
+
         # corresponds to (α_(t−δ) - α_t) divided by
         # denominator of x_t in formula (9) and plus 1
         # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 621fbfe1253b..5d885fef6b12 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -635,7 +635,7 @@ def test_clip_sample(self):
             self.check_over_configs(clip_sample=clip_sample)

     def test_prediction_type(self):
-        for prediction_type in ["epsilon", "sample"]:
+        for prediction_type in ["epsilon", "sample", "v_prediction"]:
             self.check_over_configs(prediction_type=prediction_type)

     def test_deprecated_predict_epsilon(self):
@@ -711,6 +711,37 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 258.9070) < 1e-2
         assert abs(result_mean.item() - 0.3374) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_trained_timesteps = len(scheduler)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        generator = torch.manual_seed(0)
+
+        for t in reversed(range(num_trained_timesteps)):
+            # 1. predict noise residual
+            residual = model(sample, t)
+
+            # 2. predict previous mean of sample x_t-1
+            pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
+
+            # if t > 0:
+            #     noise = self.dummy_sample_deter
+            #     variance = scheduler.get_variance(t) ** (0.5) * noise
+            #
+            # sample = pred_prev_sample + variance
+            sample = pred_prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 201.9864) < 1e-2
+        assert abs(result_mean.item() - 0.2630) < 1e-3
+

 class DDIMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDIMScheduler,)
@@ -768,6 +799,10 @@ def test_schedules(self):
         for schedule in ["linear", "squaredcos_cap_v2"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_clip_sample(self):
         for clip_sample in [True, False]:
             self.check_over_configs(clip_sample=clip_sample)
@@ -805,6 +840,15 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 172.0067) < 1e-2
         assert abs(result_mean.item() - 0.223967) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 52.5302) < 1e-2
+        assert abs(result_mean.item() - 0.0684) < 1e-3
+
     def test_full_loop_with_set_alpha_to_one(self):
         # We specify different beta, so that the first alpha is 0.99
         sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
@@ -971,6 +1015,10 @@ def test_thresholding(self):
                             solver_type=solver_type,
                         )

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_solver_order_and_type(self):
         for algorithm_type in ["dpmsolver", "dpmsolver++"]:
             for solver_type in ["midpoint", "heun"]:
@@ -1004,6 +1052,12 @@ def test_full_loop_no_noise(self):

         assert abs(result_mean.item() - 0.3301) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2251) < 1e-3
+
     def test_fp16_support(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
@@ -1184,6 +1238,10 @@ def test_schedules(self):
         for schedule in ["linear", "squaredcos_cap_v2"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_time_indices(self):
         for t in [1, 5, 10]:
             self.check_over_forward(time_step=t)
@@ -1225,6 +1283,14 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 198.1318) < 1e-2
         assert abs(result_mean.item() - 0.2580) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 67.3986) < 1e-2
+        assert abs(result_mean.item() - 0.0878) < 1e-3
+
     def test_full_loop_with_set_alpha_to_one(self):
         # We specify different beta, so that the first alpha is 0.99
         sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
@@ -1453,6 +1519,10 @@ def test_schedules(self):
         for schedule in ["linear", "scaled_linear"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_time_indices(self):
         for t in [0, 500, 800]:
             self.check_over_forward(time_step=t)
@@ -1481,6 +1551,30 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 1006.388) < 1e-2
         assert abs(result_mean.item() - 1.31) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 0.0017) < 1e-2
+        assert abs(result_mean.item() - 2.2676e-06) < 1e-3
+
     def test_full_loop_device(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1534,6 +1628,10 @@ def test_schedules(self):
         for schedule in ["linear", "scaled_linear"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1565,6 +1663,37 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 10.0807) < 1e-2
         assert abs(result_mean.item() - 0.0131) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 0.0002) < 1e-2
+        assert abs(result_mean.item() - 2.2676e-06) < 1e-3
+
     def test_full_loop_device(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1624,6 +1753,10 @@ def test_schedules(self):
         for schedule in ["linear", "scaled_linear"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1660,6 +1793,42 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 144.8084) < 1e-2
         assert abs(result_mean.item() - 0.18855) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        if torch_device in ["cpu", "mps"]:
+            assert abs(result_sum.item() - 108.4439) < 1e-2
+            assert abs(result_mean.item() - 0.1412) < 1e-3
+        else:
+            # CUDA
+            assert abs(result_sum.item() - 102.5807) < 1e-2
+            assert abs(result_mean.item() - 0.1335) < 1e-3
+
     def test_full_loop_device(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1932,6 +2101,10 @@ def test_schedules(self):
         for schedule in ["linear", "scaled_linear"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1962,6 +2135,36 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 0.1233) < 1e-2
         assert abs(result_mean.item() - 0.0002) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        if torch_device in ["cpu", "mps"]:
+            assert abs(result_sum.item() - 4.6934e-07) < 1e-2
+            assert abs(result_mean.item() - 6.1112e-10) < 1e-3
+        else:
+            # CUDA
+            assert abs(result_sum.item() - 4.693428650170972e-07) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3
+
     def test_full_loop_device(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
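
For reference, the new `v_prediction` branches above follow the v-parameterization from Salimans & Ho, "Progressive Distillation for Fast Sampling of Diffusion Models", where the network predicts v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0. Below is a minimal sketch, assuming that parameterization, of the two x_0 recoveries the diff uses; the helper names are illustrative only and not part of the diffusers API.

import torch

def pred_x0_from_v_alphabar(sample: torch.Tensor, v: torch.Tensor, alpha_prod_t: torch.Tensor) -> torch.Tensor:
    # alpha_bar form (DDPMScheduler): x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v
    beta_prod_t = 1 - alpha_prod_t
    return (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * v

def pred_x0_from_v_sigma(sample: torch.Tensor, v: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # sigma form (Euler-ancestral/Heun/LMS), where the unscaled sample satisfies
    # x_t = x_0 + sigma * noise and alpha_bar_t = 1 / (1 + sigma**2):
    # x_0 = v * c_out + x_t * c_skip, with c_out = -sigma / sqrt(sigma**2 + 1) and c_skip = 1 / (sigma**2 + 1)
    return v * (-sigma / (sigma**2 + 1) ** 0.5) + sample / (sigma**2 + 1)

PNDMScheduler instead converts the v prediction back into an epsilon prediction via eps = sqrt(alpha_bar_t) * v + sqrt(1 - alpha_bar_t) * x_t, which follows from the same identity, and then reuses its existing epsilon-based update.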