From 62668b509a3a2f1d95dfc4af0a94d44cfff2da7e Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 17 Dec 2024 17:29:59 +0000 Subject: [PATCH] Add `set_shift` to FlowMatchEulerDiscreteScheduler --- .../scheduling_flow_match_euler_discrete.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 6ddd9ac23009..c7474d56c708 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -99,10 +99,19 @@ def __init__( self._step_index = None self._begin_index = None + self._shift = shift + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() self.sigma_max = self.sigmas[0].item() + @property + def shift(self): + """ + The value used for shifting. + """ + return self._shift + @property def step_index(self): """ @@ -128,6 +137,9 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index + def set_shift(self, shift: float): + self._shift = shift + def scale_noise( self, sample: torch.FloatTensor, @@ -236,7 +248,7 @@ def set_timesteps( if self.config.use_dynamic_shifting: sigmas = self.time_shift(mu, 1.0, sigmas) else: - sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) if self.config.shift_terminal: sigmas = self.stretch_shift_to_terminal(sigmas)