From 2c887e89ddb67d815f64add8a3fd8a000268756a Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 9 Feb 2023 16:39:43 +0530
Subject: [PATCH 01/11] add store and restore() methods to EMAModel.

---
 src/diffusers/training_utils.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index c77ea03adf3e..cc6785ee0fb7 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -251,6 +251,32 @@ def state_dict(self) -> dict:
             "collected_params": self.collected_params,
         }
 
+    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        r"""
+        Args:
+        Save the current parameters for restoring later.
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored.
+        """
+        parameters = list(parameters)
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        r"""
+        Args:
+        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without
+        affecting the original optimization process. Store the parameters before the `copy_to()` method. After
+        validation (or model saving), use this to restore the former parameters.
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters. If `None`, the parameters with which this
+                `ExponentialMovingAverage` was initialized will be used.
+        """
+        if self.collected_params is None:
+            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
+        parameters = list(parameters)
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
+
     def load_state_dict(self, state_dict: dict) -> None:
         r"""
         Args:

From ee5766e28335d07256c836e24e26c929a1107c9c Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Mon, 13 Feb 2023 20:09:31 +0530
Subject: [PATCH 02/11] Update src/diffusers/training_utils.py

Co-authored-by: Patrick von Platen
---
 src/diffusers/training_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index cc6785ee0fb7..e2859c86cf35 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -259,7 +259,7 @@ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
             temporarily stored.
         """
         parameters = list(parameters)
-        self.collected_params = [param.clone() for param in parameters]
+        self.collected_params = [param.detach().cpu().clone() for param in parameters]
 
     def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""

From 5986943292e608d97f67ac949334d807265010e2 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 14 Feb 2023 16:25:57 +0000
Subject: [PATCH 03/11] make style with doc builder

---
 src/diffusers/training_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index e2859c86cf35..8d619518967d 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -264,7 +264,7 @@ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
     def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
         Args:
-        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without
+        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
         affecting the original optimization process. Store the parameters before the `copy_to()` method. After
         validation (or model saving), use this to restore the former parameters.
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 updated with the stored parameters. If `None`, the parameters with which this
                 `ExponentialMovingAverage` was initialized will be used.

From db7f031166b226971a7d40389d1f044dac86aadf Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 16 Feb 2023 08:46:48 +0530
Subject: [PATCH 04/11] remove explicit listing.

---
 src/diffusers/training_utils.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 8d619518967d..864aa586e565 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -258,7 +258,6 @@ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 temporarily stored.
         """
-        parameters = list(parameters)
         self.collected_params = [param.detach().cpu().clone() for param in parameters]
 
     def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
@@ -273,7 +272,6 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         if self.collected_params is None:
             raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
-        parameters = list(parameters)
         for c_param, param in zip(self.collected_params, parameters):
             param.data.copy_(c_param.data)
 

From 4125190ef112263c4cd6dd467925e6fe9f945d5c Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 16 Feb 2023 10:22:51 +0530
Subject: [PATCH 05/11] Apply suggestions from code review

Co-authored-by: Will Berman
---
 src/diffusers/training_utils.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 8d619518967d..864aa586e565 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -258,7 +258,6 @@ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 temporarily stored.
         """
-        parameters = list(parameters)
         self.collected_params = [param.detach().cpu().clone() for param in parameters]
 
     def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
@@ -273,7 +272,6 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         if self.collected_params is None:
             raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
-        parameters = list(parameters)
         for c_param, param in zip(self.collected_params, parameters):
             param.data.copy_(c_param.data)
 

From e6f2e3792178c89ab77f1296e462443e076d75df Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 16 Feb 2023 14:48:47 +0530
Subject: [PATCH 06/11] Apply suggestions from code review

Co-authored-by: Patrick von Platen
---
 src/diffusers/training_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 864aa586e565..976162a5a30a 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -251,7 +251,7 @@ def state_dict(self) -> dict:
             "collected_params": self.collected_params,
         }
 
-    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+    def store_non_ema(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
         Args:
         Save the current parameters for restoring later.
@@ -260,7 +260,7 @@ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         self.collected_params = [param.detach().cpu().clone() for param in parameters]
 
-    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+    def restore_non_ema(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
         Args:
         Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:

From b3a28c067ac95b17361e2a8e4bef384c5160f316 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 16 Feb 2023 14:53:53 +0530
Subject: [PATCH 07/11] chore: better variable naming.

---
 src/diffusers/training_utils.py | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 976162a5a30a..61afc812ff87 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -115,7 +115,7 @@ def __init__(
             deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
             self.to(device=kwargs["device"])
 
-        self.collected_params = None
+        self.temp_stored_params = None
 
         self.decay = decay
         self.min_decay = min_decay
@@ -149,7 +149,7 @@ def save_pretrained(self, path):
         model = self.model_cls.from_config(self.model_config)
         state_dict = self.state_dict()
         state_dict.pop("shadow_params", None)
-        state_dict.pop("collected_params", None)
+        state_dict.pop("temp_stored_params", None)
 
         model.register_to_config(**state_dict)
         self.copy_to(model.parameters())
@@ -248,19 +248,19 @@ def state_dict(self) -> dict:
             "inv_gamma": self.inv_gamma,
             "power": self.power,
             "shadow_params": self.shadow_params,
-            "collected_params": self.collected_params,
+            "temp_stored_params": self.temp_stored_params,
         }
 
-    def store_non_ema(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
         Args:
         Save the current parameters for restoring later.
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 temporarily stored.
         """
-        self.collected_params = [param.detach().cpu().clone() for param in parameters]
+        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
 
-    def restore_non_ema(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
         Args:
         Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
         affecting the original optimization process. Store the parameters before the `copy_to()` method. After
         validation (or model saving), use this to restore the former parameters.
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 updated with the stored parameters. If `None`, the parameters with which this
                 `ExponentialMovingAverage` was initialized will be used.
""" - if self.collected_params is None: + if self.temp_stored_params is None: raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") - for c_param, param in zip(self.collected_params, parameters): + for c_param, param in zip(self.temp_stored_params, parameters): param.data.copy_(c_param.data) def load_state_dict(self, state_dict: dict) -> None: @@ -322,11 +322,11 @@ def load_state_dict(self, state_dict: dict) -> None: if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): raise ValueError("shadow_params must all be Tensors") - self.collected_params = state_dict.get("collected_params", None) - if self.collected_params is not None: - if not isinstance(self.collected_params, list): - raise ValueError("collected_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.collected_params): - raise ValueError("collected_params must all be Tensors") - if len(self.collected_params) != len(self.shadow_params): - raise ValueError("collected_params and shadow_params must have the same length") + self.temp_stored_params = state_dict.get("temp_stored_params", None) + if self.temp_stored_params is not None: + if not isinstance(self.temp_stored_params, list): + raise ValueError("temp_stored_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.temp_stored_params): + raise ValueError("temp_stored_params must all be Tensors") + if len(self.temp_stored_params) != len(self.shadow_params): + raise ValueError("temp_stored_params and shadow_params must have the same length") From 635313ca278752b5b779d2c9daf88f67ece1f936 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Feb 2023 17:24:19 +0530 Subject: [PATCH 08/11] better treatment of temp_stored_params Co-authored-by: patil-suraj --- src/diffusers/training_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 61afc812ff87..4be93f4d6e06 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -248,7 +248,6 @@ def state_dict(self) -> dict: "inv_gamma": self.inv_gamma, "power": self.power, "shadow_params": self.shadow_params, - "temp_stored_params": self.temp_stored_params, } def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: @@ -275,6 +274,9 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: for c_param, param in zip(self.temp_stored_params, parameters): param.data.copy_(c_param.data) + # Better memory-wise. + self.temp_stored_params = None + def load_state_dict(self, state_dict: dict) -> None: r""" Args: From 9adac8db1c334c9e06be16b2b176f8975dc06f85 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Feb 2023 17:26:43 +0530 Subject: [PATCH 09/11] make style --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 4be93f4d6e06..3dec83187d9f 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -274,7 +274,7 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: for c_param, param in zip(self.temp_stored_params, parameters): param.data.copy_(c_param.data) - # Better memory-wise. + # Better memory-wise. 
         self.temp_stored_params = None
 
     def load_state_dict(self, state_dict: dict) -> None:

From a2e44779bbd94fb1cb6d131cb6e35b625c9e4424 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 16 Feb 2023 17:35:09 +0530
Subject: [PATCH 10/11] =?UTF-8?q?remove=20temporary=20params=20from=20eart?=
 =?UTF-8?q?h=20=F0=9F=8C=8E?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/diffusers/training_utils.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 3dec83187d9f..67a8e48d381f 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -149,7 +149,6 @@ def save_pretrained(self, path):
         model = self.model_cls.from_config(self.model_config)
         state_dict = self.state_dict()
         state_dict.pop("shadow_params", None)
-        state_dict.pop("temp_stored_params", None)
 
         model.register_to_config(**state_dict)
         self.copy_to(model.parameters())
@@ -323,12 +322,3 @@ def load_state_dict(self, state_dict: dict) -> None:
             raise ValueError("shadow_params must be a list")
         if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
             raise ValueError("shadow_params must all be Tensors")
-
-        self.temp_stored_params = state_dict.get("temp_stored_params", None)
-        if self.temp_stored_params is not None:
-            if not isinstance(self.temp_stored_params, list):
-                raise ValueError("temp_stored_params must be a list")
-            if not all(isinstance(p, torch.Tensor) for p in self.temp_stored_params):
-                raise ValueError("temp_stored_params must all be Tensors")
-            if len(self.temp_stored_params) != len(self.shadow_params):
-                raise ValueError("temp_stored_params and shadow_params must have the same length")

From 58738646e4fafb1495f36cdccf88b92f2a9193e7 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 16 Feb 2023 19:06:51 +0530
Subject: [PATCH 11/11] make fix-copies.

---
 src/diffusers/utils/dummy_torch_and_transformers_objects.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index a1394292d7fd..6b8ddd2a0ef8 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -212,7 +212,7 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
-class StableDiffusionSAGPipeline(metaclass=DummyObject):
+class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
     def __init__(self, *args, **kwargs):
@@ -227,7 +227,7 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
-class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
+class StableDiffusionSAGPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
     def __init__(self, *args, **kwargs):
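
As a usage sketch (not part of these commits): the `store()` / `copy_to()` / `restore()` flow introduced by this patch series is meant to be wrapped around validation in a training loop, so the live training weights are swapped out for the EMA weights only temporarily. The model choice, decay value, and validation cadence below are illustrative assumptions, not taken from the diff.

import torch
from diffusers import UNet2DModel
from diffusers.training_utils import EMAModel

# Toy model and EMA tracker; any nn.Module's parameters work the same way.
model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
ema_model = EMAModel(model.parameters(), decay=0.9999)

for step in range(1000):
    # ... forward pass, loss.backward(), optimizer.step() ...
    ema_model.step(model.parameters())  # update the EMA shadow weights

    if step % 500 == 0:
        # Stash the live weights (kept on CPU by store()), swap in the EMA
        # weights, run validation or save a checkpoint, then swap back.
        ema_model.store(model.parameters())
        ema_model.copy_to(model.parameters())
        # ... validate / save with EMA weights ...
        ema_model.restore(model.parameters())  # also drops temp_stored_params

Keeping the stashed copies on CPU (the `param.detach().cpu().clone()` change) and clearing `temp_stored_params` inside `restore()` are the memory-oriented refinements the later commits in this series make.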