40 changes: 28 additions & 12 deletions src/diffusers/training_utils.py
@@ -115,7 +115,7 @@ def __init__(
deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
self.to(device=kwargs["device"])

self.collected_params = None
self.temp_stored_params = None

self.decay = decay
self.min_decay = min_decay
@@ -149,7 +149,6 @@ def save_pretrained(self, path):
model = self.model_cls.from_config(self.model_config)
state_dict = self.state_dict()
state_dict.pop("shadow_params", None)
state_dict.pop("collected_params", None)

model.register_to_config(**state_dict)
self.copy_to(model.parameters())
@@ -248,9 +247,35 @@ def state_dict(self) -> dict:
"inv_gamma": self.inv_gamma,
"power": self.power,
"shadow_params": self.shadow_params,
"collected_params": self.collected_params,
}

def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
r"""
Args:
Save the current parameters for restoring later.
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
r"""
Args:
Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
affecting the original optimization process. Store the parameters before the `copy_to()` method. After
validation (or model saving), use this to restore the former parameters.
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters. If `None`, the parameters with which this
`ExponentialMovingAverage` was initialized will be used.
"""
if self.temp_stored_params is None:
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
for c_param, param in zip(self.temp_stored_params, parameters):
param.data.copy_(c_param.data)

# Better memory-wise.
self.temp_stored_params = None

def load_state_dict(self, state_dict: dict) -> None:
r"""
Args:
@@ -297,12 +322,3 @@ def load_state_dict(self, state_dict: dict) -> None:
raise ValueError("shadow_params must be a list")
if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
raise ValueError("shadow_params must all be Tensors")

self.collected_params = state_dict.get("collected_params", None)
if self.collected_params is not None:
if not isinstance(self.collected_params, list):
raise ValueError("collected_params must be a list")
if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
raise ValueError("collected_params must all be Tensors")
if len(self.collected_params) != len(self.shadow_params):
raise ValueError("collected_params and shadow_params must have the same length")
4 changes: 2 additions & 2 deletions src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -212,7 +212,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class StableDiffusionSAGPipeline(metaclass=DummyObject):
class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
@@ -227,7 +227,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
class StableDiffusionSAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):