From 0623b1868dae98290c130388288a232f75bcfb7a Mon Sep 17 00:00:00 2001
From: Chenguo Lin
Date: Wed, 30 Oct 2024 20:45:36 +0800
Subject: [PATCH] fix `EMAModel.use_ema_warmup`

For a long time I could not understand why `use_ema_warmup` is forced to
`True` in `EMAModel.__init__()`. It is also unreasonable to still apply a
warmup schedule via `cur_decay_value = (1 + step) / (10 + step)` when
`use_ema_warmup=False`; in that case the decay should simply be the
constant `self.decay`.
---
 src/diffusers/training_utils.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 0e0d0ce5b568..3ad333e7777e 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -341,9 +341,6 @@ def __init__(
             )
             parameters = parameters.parameters()
 
-            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
-            use_ema_warmup = True
-
         if kwargs.get("max_value", None) is not None:
             deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
             deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
@@ -414,7 +411,7 @@ def get_decay(self, optimization_step: int) -> float:
         if self.use_ema_warmup:
             cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
         else:
-            cur_decay_value = (1 + step) / (10 + step)
+            cur_decay_value = self.decay
 
         cur_decay_value = min(cur_decay_value, self.decay)
         # make sure decay is not smaller than min_decay
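
Note (not part of the patch): a minimal standalone sketch of the behavior
change, reimplementing only the `use_ema_warmup=False` branch of
`EMAModel.get_decay`. The default decay of 0.9999 mirrors the diffusers
default; the helper names below are hypothetical, for illustration only.

    def get_decay_before(step: int, decay: float = 0.9999) -> float:
        # old behavior: a warmup-like ramp even though warmup was disabled
        return min((1 + step) / (10 + step), decay)

    def get_decay_after(step: int, decay: float = 0.9999) -> float:
        # new behavior: just the configured constant; the subsequent
        # min(cur_decay_value, self.decay) in get_decay is now a no-op
        return decay

    for step in (1, 10, 100, 10_000):
        print(step, round(get_decay_before(step), 4), get_decay_after(step))
    # step=1: before=0.1818, after=0.9999 -- the old code ramped the decay
    # up slowly even when the user had requested no warmup.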