From 57c19e3796c9d387015506153c09a4e300472fbf Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 30 Dec 2022 12:11:23 +0100
Subject: [PATCH 1/2] Fix ema decay and clarify nomenclature.

---
 examples/text_to_image/train_text_to_image.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 224fe471889e..3fccef4fbcd8 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -278,24 +278,19 @@ def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
         self.decay = decay
         self.optimization_step = 0
 
-    def get_decay(self, optimization_step):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        value = (1 + optimization_step) / (10 + optimization_step)
-        return 1 - min(self.decay, value)
-
     @torch.no_grad()
     def step(self, parameters):
         parameters = list(parameters)
 
         self.optimization_step += 1
-        self.decay = self.get_decay(self.optimization_step)
+
+        # Compute the decay factor for the exponential moving average.
+        current_decay = (1 + self.optimization_step) / (10 + self.optimization_step)
+        one_minus_decay = 1 - min(self.decay, current_decay)
 
         for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                tmp = self.decay * (s_param - param)
-                s_param.sub_(tmp)
+                s_param.sub_(one_minus_decay * (s_param - param))
             else:
                 s_param.copy_(param)
 

From 65a0fd557fc7c90d2c0c5e71dd745825b95ca2fb Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 30 Dec 2022 12:32:32 +0100
Subject: [PATCH 2/2] Rename var.

---
 examples/text_to_image/train_text_to_image.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 3fccef4fbcd8..6850d9cafcd1 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -285,8 +285,8 @@ def step(self, parameters):
         self.optimization_step += 1
 
         # Compute the decay factor for the exponential moving average.
-        current_decay = (1 + self.optimization_step) / (10 + self.optimization_step)
-        one_minus_decay = 1 - min(self.decay, current_decay)
+        value = (1 + self.optimization_step) / (10 + self.optimization_step)
+        one_minus_decay = 1 - min(self.decay, value)
 
         for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
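
Note: the series keeps `self.decay` as the maximum decay and recomputes the warmup factor on every step, then subtracts `one_minus_decay * (shadow - param)` rather than `decay * (shadow - param)`. The sketch below restates that corrected update as a standalone function for reference; the names `ema_step`, `shadow_params`, and `max_decay` are illustrative and not taken from the training script.

import torch

def ema_step(shadow_params, parameters, optimization_step, max_decay=0.9999):
    # Warmup schedule from the patch: (1 + step) / (10 + step) grows toward 1
    # over the first steps and is capped at max_decay.
    warmup_decay = (1 + optimization_step) / (10 + optimization_step)
    one_minus_decay = 1 - min(max_decay, warmup_decay)
    with torch.no_grad():
        for s_param, param in zip(shadow_params, parameters):
            if param.requires_grad:
                # shadow <- decay * shadow + (1 - decay) * param,
                # written as an in-place subtraction.
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                # Non-trainable parameters are copied directly.
                s_param.copy_(param)

A typical use would be to keep a detached copy of the model's parameters as the shadow and call ema_step(shadow, model.parameters(), step) after each optimizer step.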