diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index c77ea03adf3e..67a8e48d381f 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -115,7 +115,7 @@ def __init__(
             deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
             self.to(device=kwargs["device"])
 
-        self.collected_params = None
+        self.temp_stored_params = None
 
         self.decay = decay
         self.min_decay = min_decay
@@ -149,7 +149,6 @@ def save_pretrained(self, path):
         model = self.model_cls.from_config(self.model_config)
         state_dict = self.state_dict()
         state_dict.pop("shadow_params", None)
-        state_dict.pop("collected_params", None)
 
         model.register_to_config(**state_dict)
         self.copy_to(model.parameters())
@@ -248,9 +247,35 @@ def state_dict(self) -> dict:
             "inv_gamma": self.inv_gamma,
             "power": self.power,
             "shadow_params": self.shadow_params,
-            "collected_params": self.collected_params,
         }
 
+    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        r"""
+        Save the current parameters for restoring later.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored.
+        """
+        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
+
+    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        r"""
+        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
+        without affecting the original optimization process. Store the parameters before the `copy_to()` method.
+        After validation (or model saving), use this to restore the former parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters. If `None`, the parameters with which this
+                `ExponentialMovingAverage` was initialized will be used.
+        """
+        if self.temp_stored_params is None:
+            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
+        for c_param, param in zip(self.temp_stored_params, parameters):
+            param.data.copy_(c_param.data)
+
+        # Better memory-wise.
+        self.temp_stored_params = None
+
     def load_state_dict(self, state_dict: dict) -> None:
         r"""
         Args:
@@ -297,12 +322,3 @@ def load_state_dict(self, state_dict: dict) -> None:
             raise ValueError("shadow_params must be a list")
         if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
             raise ValueError("shadow_params must all be Tensors")
-
-        self.collected_params = state_dict.get("collected_params", None)
-        if self.collected_params is not None:
-            if not isinstance(self.collected_params, list):
-                raise ValueError("collected_params must be a list")
-            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
-                raise ValueError("collected_params must all be Tensors")
-            if len(self.collected_params) != len(self.shadow_params):
-                raise ValueError("collected_params and shadow_params must have the same length")
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index a1394292d7fd..6b8ddd2a0ef8 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -212,7 +212,7 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
-class StableDiffusionSAGPipeline(metaclass=DummyObject):
+class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
     def __init__(self, *args, **kwargs):
@@ -227,7 +227,7 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
-class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
+class StableDiffusionSAGPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
     def __init__(self, *args, **kwargs):
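For context, a minimal sketch of how the new `store()` / `restore()` pair is meant to be used around a validation step, following the docstrings above: stash the live weights with `store()`, swap in the EMA weights with `copy_to()`, evaluate, then `restore()`. The toy model, optimizer, data, and hyperparameters below are placeholder assumptions, and the `EMAModel(...)` constructor call and `step()` update are taken from the surrounding class rather than this diff, so they may differ between versions:

```python
# Hypothetical usage sketch for EMAModel.store()/copy_to()/restore().
# The toy model, data, and hyperparameters are placeholders, not part of this diff.
import torch

from diffusers.training_utils import EMAModel

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
ema = EMAModel(model.parameters(), decay=0.999)  # constructor signature assumed

for step in range(200):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.step(model.parameters())  # update the EMA shadow weights

    if step % 100 == 0:
        ema.store(model.parameters())    # stash the live training weights on CPU
        ema.copy_to(model.parameters())  # evaluate with the EMA weights
        with torch.no_grad():
            val_loss = torch.nn.functional.mse_loss(model(x), y)
        ema.restore(model.parameters())  # put the training weights back, drop the temp copy
        print(f"step {step}: train {loss.item():.4f}, ema-val {val_loss.item():.4f}")
```

Setting `self.temp_stored_params = None` at the end of `restore()` releases the CPU copies as soon as they are no longer needed, which is why the stored weights can only be restored once per `store()` call.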