From 92606665e192e548f71b282a4fc103e95c6bef70 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 20 Jan 2023 06:37:56 +0100 Subject: [PATCH 01/14] better accelerated saving --- examples/dreambooth/train_dreambooth.py | 29 +++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 654616eb9849..9f99f2217953 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -585,6 +585,35 @@ def main(args): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + for model in models: + sub_dir = "unet" if type(model) == type(unet) else "text_encoder" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + if type(model) == type(text_encoder): + # load transformer style into model + load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") + model.config = load_model.config + else: + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + vae.requires_grad_(False) if not args.train_text_encoder: text_encoder.requires_grad_(False) From 423f357f140aa967a00fece2f1b538b62ae28b82 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 27 Jan 2023 19:34:21 +0200 Subject: [PATCH 02/14] up --- examples/dreambooth/train_dreambooth.py | 50 ++++++++++--------- examples/text_to_image/train_text_to_image.py | 39 +++++++++++++++ .../train_unconditional.py | 39 +++++++++++++++ 3 files changed, 105 insertions(+), 23 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 9f99f2217953..b4138123d86b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -28,6 +28,7 @@ import torch.utils.checkpoint from torch.utils.data import Dataset +import accelerate import datasets import diffusers import transformers @@ -39,6 +40,7 @@ from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available from huggingface_hub import HfFolder, Repository, whoami +from packaging import version from PIL import Image from torchvision import transforms from tqdm.auto import tqdm @@ -585,34 +587,36 @@ def main(args): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - for model in models: - sub_dir = "unet" if type(model) == type(unet) else "text_encoder" - model.save_pretrained(os.path.join(output_dir, sub_dir)) + # `accelerate` 0.15.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"): + # create custom saving & loading hooks so that 
`accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + for model in models: + sub_dir = "unet" if type(model) == type(unet) else "text_encoder" + model.save_pretrained(os.path.join(output_dir, sub_dir)) - # make sure to pop weight so that corresponding model is not saved again - weights.pop() + # make sure to pop weight so that corresponding model is not saved again + weights.pop() - def load_model_hook(models, input_dir): - while len(models) > 0: - # pop models so that they are not loaded again - model = models.pop() + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() - if type(model) == type(text_encoder): - # load transformer style into model - load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") - model.config = load_model.config - else: - # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") - model.register_to_config(**load_model.config) + if type(model) == type(text_encoder): + # load transformer style into model + load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") + model.config = load_model.config + else: + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) - model.load_state_dict(load_model.state_dict()) - del load_model + model.load_state_dict(load_model.state_dict()) + del load_model - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) vae.requires_grad_(False) if not args.train_text_encoder: diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index b77f90566799..dd2047b7733d 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -27,6 +27,7 @@ import torch.nn.functional as F import torch.utils.checkpoint +import accelerate import datasets import diffusers import transformers @@ -39,6 +40,7 @@ from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available from huggingface_hub import HfFolder, Repository, whoami +from packaging import version from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer @@ -492,6 +494,43 @@ def main(): else: raise ValueError("xformers is not available. 
Make sure it is installed correctly") + # `accelerate` 0.15.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + for i, model in enumerate(models): + # if we use_ema, we save it under "unet_ema" + if args.use_ema and i == 1: + sub_folder = "unet_ema" + else: + sub_folder = "unet" + + model.save_pretrained(os.path.join(output_dir, sub_folder)) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # if we use_ema, we save it under "unet_ema" + if args.use_ema and i == 1: + sub_folder = "unet_ema" + else: + sub_folder = "unet" + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=sub_folder) + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 25d4b8798363..754da31fd565 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -8,6 +8,7 @@ import torch import torch.nn.functional as F +import accelerate from accelerate import Accelerator from accelerate.logging import get_logger from datasets import load_dataset @@ -16,6 +17,7 @@ from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version from huggingface_hub import HfFolder, Repository, whoami +from packaging import version from torchvision.transforms import ( CenterCrop, Compose, @@ -262,6 +264,43 @@ def main(args): logging_dir=logging_dir, ) + # `accelerate` 0.15.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + for i, model in enumerate(models): + # if we use_ema, we save it under "unet_ema" + if args.use_ema and i == 1: + sub_folder = "unet_ema" + else: + sub_folder = "unet" + + model.save_pretrained(os.path.join(output_dir, sub_folder)) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # if we use_ema, we save it under "unet_ema" + if args.use_ema and i == 1: + sub_folder = "unet_ema" + else: + sub_folder = "unet" + + # load diffusers style into model + load_model = UNet2DModel.from_pretrained(input_dir, subfolder=sub_folder) + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + model = UNet2DModel( sample_size=args.resolution, in_channels=3, From 
0b41614c1e54212a00338a5587e97be0ccf8ef70 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 30 Jan 2023 12:12:03 +0000
Subject: [PATCH 03/14] finish

---
 docs/source/en/training/dreambooth.mdx        |  31 +++-
 examples/text_to_image/train_text_to_image.py | 169 ++++--------------
 .../train_unconditional.py                    |  34 ++--
 src/diffusers/training_utils.py               |   9 +-
 4 files changed, 92 insertions(+), 151 deletions(-)

diff --git a/docs/source/en/training/dreambooth.mdx b/docs/source/en/training/dreambooth.mdx
index 7f9de798c470..6e1ecb2b1059 100644
--- a/docs/source/en/training/dreambooth.mdx
+++ b/docs/source/en/training/dreambooth.mdx
@@ -127,7 +127,30 @@ This would be a good opportunity to tweak some of your hyperparameters if you wi
 Saved checkpoints are stored in a format suitable for resuming training. They not only include the model weights,
 but also the state of the optimizer, data loaders and learning rate.
 
-You can use a checkpoint for inference, but first you need to convert it to an inference pipeline. This is how you could do it:
+**Note**: If you have installed `"accelerate>=0.15.0.dev0"` you can use the following code to run
+inference from an intermediate checkpoint.
+
+```python
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+from transformers import CLIPTextModel
+import torch
+
+# Load the pipeline with the same arguments (model, revision) that were used for training
+model_id = "CompVis/stable-diffusion-v1-4"
+
+unet = UNet2DConditionModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/unet")
+
+# if you have trained with `--train_text_encoder` make sure to also load the text encoder
+text_encoder = CLIPTextModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/text_encoder")
+
+pipeline = DiffusionPipeline.from_pretrained(model_id, unet=unet, text_encoder=text_encoder, torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+# Perform inference, or save, or push to the hub
+pipeline.save_pretrained("dreambooth-pipeline")
+```
+
+If you have installed `"accelerate<0.15.0.dev0"` you need to first convert it to an inference pipeline. This is how you could do it:
 
 ```python
 from accelerate import Accelerator
@@ -271,6 +294,10 @@ accelerate launch train_dreambooth.py \
 Once you have trained a model, inference can be done using the `StableDiffusionPipeline`, by simply indicating the path where the model
 was saved. Make sure that your prompts include the special `identifier` used during training (`sks` in the previous examples).
 
+**Note**: If you have installed `"accelerate>=0.15.0.dev0"` you can use the following code to run
+inference from an intermediate checkpoint.
+
+
 ```python
 from diffusers import StableDiffusionPipeline
 import torch
@@ -284,4 +311,4 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
 image.save("dog-bucket.png")
 ```
 
-You may also run inference from [any of the saved training checkpoints](#performing-inference-using-a-saved-checkpoint).
\ No newline at end of file
+You may also run inference from [any of the saved training checkpoints](#performing-inference-using-a-saved-checkpoint).
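The `checkpoint-<step>/unet` and `checkpoint-<step>/text_encoder` folders that the documentation above loads from are produced by the save/load hooks registered in `train_dreambooth.py` earlier in this series. A minimal, self-contained sketch of that mechanism, assuming `accelerate>=0.15.0.dev0`; the base model and the `checkpoint-0` path below are placeholders:

```python
# Minimal sketch of the checkpoint hooks the training scripts register; assumes
# accelerate>=0.15.0.dev0. The base model and the "checkpoint-0" path are placeholders.
import os

from accelerate import Accelerator
from diffusers import UNet2DConditionModel

accelerator = Accelerator()
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")


def save_model_hook(models, weights, output_dir):
    # called by `accelerator.save_state(...)` before its default serialization
    for model in models:
        model.save_pretrained(os.path.join(output_dir, "unet"))
        # pop the weights so accelerate does not also dump a raw state dict for this model
        weights.pop()


def load_model_hook(models, input_dir):
    # called by `accelerator.load_state(...)` so the model is restored from the diffusers layout
    while len(models) > 0:
        model = models.pop()
        loaded = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
        model.register_to_config(**loaded.config)
        model.load_state_dict(loaded.state_dict())
        del loaded


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

unet = accelerator.prepare(unet)
accelerator.save_state("checkpoint-0")  # writes checkpoint-0/unet/ in `save_pretrained` format
```

Because each model is written with `save_pretrained`, the resulting `checkpoint-*/unet` (and, for DreamBooth with `--train_text_encoder`, `checkpoint-*/text_encoder`) folder can be passed directly to `from_pretrained`, which is why the inference snippet above no longer needs a conversion step.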
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index dd2047b7733d..5c530063cbbe 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and import argparse -import copy import logging import math import os import random from pathlib import Path -from typing import Iterable, Optional +from typing import Optional import numpy as np import torch @@ -37,7 +36,8 @@ from datasets import load_dataset from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate from diffusers.utils.import_utils import is_xformers_available from huggingface_hub import HfFolder, Repository, whoami from packaging import version @@ -201,6 +201,9 @@ def parse_args(): ), ) parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.") + parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.") + parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.") parser.add_argument( "--non_ema_revision", type=str, @@ -307,117 +310,18 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: } -# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 -class EMAModel: - """ - Exponential Moving Average of models weights - """ - - def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters] - - self.collected_params = None - - self.decay = decay - self.optimization_step = 0 - - @torch.no_grad() - def step(self, parameters): - parameters = list(parameters) - - self.optimization_step += 1 - - # Compute the decay factor for the exponential moving average. - value = (1 + self.optimization_step) / (10 + self.optimization_step) - one_minus_decay = 1 - min(self.decay, value) - - for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) - - torch.cuda.empty_cache() - - def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: - """ - Copy current averaged parameters into given collection of parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. - """ - parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.data) - - def to(self, device=None, dtype=None) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
- - Args: - device: like `device` argument to `torch.Tensor.to` - """ - # .to() on the tensors handles None correctly - self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) - for p in self.shadow_params - ] - - def state_dict(self) -> dict: - r""" - Returns the state of the ExponentialMovingAverage as a dict. - This method is used by accelerate during checkpointing to save the ema state dict. - """ - # Following PyTorch conventions, references to tensors are returned: - # "returns a reference to the state and not its copy!" - - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict - return { - "decay": self.decay, - "optimization_step": self.optimization_step, - "shadow_params": self.shadow_params, - "collected_params": self.collected_params, - } - - def load_state_dict(self, state_dict: dict) -> None: - r""" - Loads the ExponentialMovingAverage state. - This method is used by accelerate during checkpointing to save the ema state dict. - Args: - state_dict (dict): EMA state. Should be an object returned - from a call to :meth:`state_dict`. - """ - # deepcopy, to be consistent with module API - state_dict = copy.deepcopy(state_dict) - - self.decay = state_dict["decay"] - if self.decay < 0.0 or self.decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.optimization_step = state_dict["optimization_step"] - if not isinstance(self.optimization_step, int): - raise ValueError("Invalid optimization_step") - - self.shadow_params = state_dict["shadow_params"] - if not isinstance(self.shadow_params, list): - raise ValueError("shadow_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): - raise ValueError("shadow_params must all be Tensors") - - self.collected_params = state_dict["collected_params"] - if self.collected_params is not None: - if not isinstance(self.collected_params, list): - raise ValueError("collected_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.collected_params): - raise ValueError("collected_params must all be Tensors") - if len(self.collected_params) != len(self.shadow_params): - raise ValueError("collected_params and shadow_params must have the same length") - - def main(): args = parse_args() + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) logging_dir = os.path.join(args.output_dir, args.logging_dir) accelerator = Accelerator( @@ -483,10 +387,13 @@ def main(): # Create EMA for the unet. 
if args.use_ema: - ema_unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ema_unet = EMAModel( + unet.parameters(), + decay=args.ema_max_decay, + use_ema_warmup=True, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, ) - ema_unet = EMAModel(ema_unet.parameters()) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -499,13 +406,7 @@ def main(): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): for i, model in enumerate(models): - # if we use_ema, we save it under "unet_ema" - if args.use_ema and i == 1: - sub_folder = "unet_ema" - else: - sub_folder = "unet" - - model.save_pretrained(os.path.join(output_dir, sub_folder)) + model.save_pretrained(os.path.join(output_dir, "unet")) # make sure to pop weight so that corresponding model is not saved again weights.pop() @@ -515,14 +416,8 @@ def load_model_hook(models, input_dir): # pop models so that they are not loaded again model = models.pop() - # if we use_ema, we save it under "unet_ema" - if args.use_ema and i == 1: - sub_folder = "unet_ema" - else: - sub_folder = "unet" - # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=sub_folder) + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) @@ -683,8 +578,6 @@ def collate_fn(examples): unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, optimizer, train_dataloader, lr_scheduler ) - if args.use_ema: - accelerator.register_for_checkpointing(ema_unet) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. 
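The `EMAModel` constructed above opts into @crowsonkb's warmup schedule (`use_ema_warmup=True`): the decay ramps up as `1 - (1 + step / inv_gamma) ** -power` and is clamped to `[min_decay, decay]`, instead of the `(1 + step) / (10 + step)` ramp of the example-local `EMAModel` that this patch removes. A rough, standalone sketch of that ramp with the defaults added in this patch (`inv_gamma=1.0`, `power=3/4`, `decay=0.9999`; `min_decay=0.0` is the `EMAModel` default); the printed values are approximate:

```python
# Rough sketch of the warmup branch of EMAModel.get_decay with update_after_step=0,
# as configured by the training scripts; numbers are approximate.
def ema_decay(step, inv_gamma=1.0, power=3 / 4, decay=0.9999, min_decay=0.0):
    step = max(0, step - 1)
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min(value, decay), min_decay)


for step in (100, 1_000, 10_000, 215_000):
    print(step, round(ema_decay(step), 4))
# -> 0.9684, 0.9944, 0.999, 0.9999: the decay ramps up over training, matching
#    @crowsonkb's notes quoted in the EMAModel docstring (0.999 at ~10K steps for power=3/4).
```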
@@ -739,6 +632,17 @@ def collate_fn(examples): accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) + if args.use_ema: + ema_model = UNet2DConditionModel.from_pretrained(os.path.join(args.output_dir, path, "unet_ema")) + ema_model = EMAModel( + ema_model, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + max_value=args.ema_max_decay, + device=accelerator.unwrap_model(unet).device, + optimization_step=global_step, + ) + first_epoch = global_step // num_update_steps_per_epoch resume_step = global_step % num_update_steps_per_epoch @@ -814,6 +718,9 @@ def collate_fn(examples): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + if args.use_ema: + ema_unet.averaged_model.save_pretrained(os.path.join(save_path, "unet_ema")) + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 754da31fd565..c9779ceeabbf 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -269,13 +269,7 @@ def main(args): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): for i, model in enumerate(models): - # if we use_ema, we save it under "unet_ema" - if args.use_ema and i == 1: - sub_folder = "unet_ema" - else: - sub_folder = "unet" - - model.save_pretrained(os.path.join(output_dir, sub_folder)) + model.save_pretrained(os.path.join(output_dir, "unet")) # make sure to pop weight so that corresponding model is not saved again weights.pop() @@ -285,14 +279,8 @@ def load_model_hook(models, input_dir): # pop models so that they are not loaded again model = models.pop() - # if we use_ema, we save it under "unet_ema" - if args.use_ema and i == 1: - sub_folder = "unet_ema" - else: - sub_folder = "unet" - # load diffusers style into model - load_model = UNet2DModel.from_pretrained(input_dir, subfolder=sub_folder) + load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet") model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) @@ -384,7 +372,6 @@ def transforms(examples): model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) - accelerator.register_for_checkpointing(lr_scheduler) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -432,6 +419,17 @@ def transforms(examples): accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) + if args.use_ema: + ema_model = UNet2DModel.from_pretrained(os.path.join(args.output_dir, path, "unet_ema")) + ema_model = EMAModel( + ema_model, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + max_value=args.ema_max_decay, + device=accelerator.unwrap_model(model).device, + optimization_step=global_step, + ) + resume_global_step = global_step * args.gradient_accumulation_steps first_epoch = resume_global_step // num_update_steps_per_epoch resume_step = resume_global_step % num_update_steps_per_epoch @@ -497,6 +495,10 @@ def transforms(examples): if accelerator.is_main_process: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) + + if args.use_ema: + 
ema_model.averaged_model.save_pretrained(os.path.join(save_path, "unet_ema")) + logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} @@ -512,7 +514,7 @@ def transforms(examples): if accelerator.is_main_process: if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: pipeline = DDPMPipeline( - unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model), + unet=accelerator.unwrap_model(model.averaged_model if args.use_ema else model), scheduler=noise_scheduler, ) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index fefc490c1f01..1521d10281cf 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -53,6 +53,7 @@ def __init__( min_value=0.0, max_value=0.9999, device=None, + optimization_step=0, ): """ @crowsonkb's notes on EMA Warmup: @@ -79,8 +80,12 @@ def __init__( if device is not None: self.averaged_model = self.averaged_model.to(device=device) - self.decay = 0.0 - self.optimization_step = 0 + if optimization_step == 0: + self.decay = 0.0 + self.optimization_step = 0 + else: + self.optimization_step = optimization_step + self.decay = self.get_decay(optimization_step) def get_decay(self, optimization_step): """ From b551de530693c62abe163314359b6523411b52a3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 30 Jan 2023 12:57:34 +0000 Subject: [PATCH 04/14] finish --- _ | 278 ++++++++++++++++++ examples/text_to_image/train_text_to_image.py | 30 +- .../train_unconditional.py | 35 +-- 3 files changed, 290 insertions(+), 53 deletions(-) create mode 100644 _ diff --git a/_ b/_ new file mode 100644 index 000000000000..2f8856b616ac --- /dev/null +++ b/_ @@ -0,0 +1,278 @@ +import copy +import os +import random +from typing import Iterable, Union + +import numpy as np +import torch + +from .utils import deprecate + + +def enable_full_determinism(seed: int): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + """ + # set seed first + set_seed(seed) + + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def set_seed(seed: int): + """ + Args: + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + + +# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 +class EMAModel: + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + use_ema_warmup: bool = False, + inv_gamma: Union[float, int] = 1.0, + power: Union[float, int] = 2 / 3, + optimization_step=0, + **kwargs, + ): + """ + Args: + parameters (Iterable[torch.nn.Parameter]): The parameters to track. 
+ decay (float): The decay factor for the exponential moving average. + min_decay (float): The minimum decay factor for the exponential moving average. + update_after_step (int): The number of steps to wait before starting to update the EMA weights. + use_ema_warmup (bool): Whether to use EMA warmup. + inv_gamma (float): + Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. + power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA + weights will be stored on CPU. + + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + """ + + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + + # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility + use_ema_warmup = True + + if kwargs.get("max_value", None) is not None: + deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." + deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) + decay = kwargs["max_value"] + + if kwargs.get("min_value", None) is not None: + deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." + deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) + min_decay = kwargs["min_value"] + + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + + if kwargs.get("device", None) is not None: + deprecation_message = "The `device` argument is deprecated. Please use `to` instead." + deprecate("device", "1.0.0", deprecation_message, standard_warn=False) + self.to(device=kwargs["device"]) + + self.collected_params = None + + self.decay = decay + self.min_decay = min_decay + self.update_after_step = update_after_step + self.use_ema_warmup = use_ema_warmup + self.inv_gamma = inv_gamma + self.power = power + + if optimization_step == 0: + self.optimization_step = 0 + else: + self.optimization_step = optimization_step + + def get_decay(self, optimization_step: int) -> float: + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + + if step <= 0: + return 0.0 + + if self.use_ema_warmup: + cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power + else: + cur_decay_value = (1 + step) / (10 + step) + + cur_decay_value = min(cur_decay_value, self.decay) + # make sure decay is not smaller than min_decay + cur_decay_value = max(cur_decay_value, self.min_decay) + return cur_decay_value + + @torch.no_grad() + def step(self, parameters: Iterable[torch.nn.Parameter]): + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. 
" + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + + parameters = list(parameters) + + self.optimization_step += 1 + + # Compute the decay factor for the exponential moving average. + decay = self.get_decay(self.optimization_step) + one_minus_decay = 1 - decay + + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) + + torch.cuda.empty_cache() + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during + checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "min_decay": self.decay, + "optimization_step": self.optimization_step, + "update_after_step": self.update_after_step, + "use_ema_warmup": self.use_ema_warmup, + "inv_gamma": self.inv_gamma, + "power": self.power, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params, + } + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Args: + Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the + ema state dict. + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict.get("decay", self.decay) + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.min_decay = state_dict.get("min_decay", self.min_decay) + if not isinstance(self.min_decay, float): + raise ValueError("Invalid min_decay") + + self.optimization_step = state_dict.get("optimization_step", self.optimization_step) + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.update_after_step = state_dict.get("update_after_step", self.update_after_step) + if not isinstance(self.update_after_step, int): + raise ValueError("Invalid update_after_step") + + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) + if not isinstance(self.use_ema_warmup, bool): + raise ValueError("Invalid use_ema_warmup") + + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) + if not isinstance(self.inv_gamma, (float, int)): + raise ValueError("Invalid inv_gamma") + + self.power = state_dict.get("power", self.power) + if not isinstance(self.power, (float, int)): + raise ValueError("Invalid power") + + self.shadow_params = state_dict["shadow_params"] + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + if not isinstance(self.collected_params, list): + raise ValueError("collected_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.collected_params): + raise ValueError("collected_params must all be Tensors") + if len(self.collected_params) != len(self.shadow_params): + raise ValueError("collected_params and shadow_params must have the same length") diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index a8956347b196..75033f444eb6 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -41,7 +41,6 @@ from diffusers.utils.import_utils import is_xformers_available from huggingface_hub import HfFolder, Repository, create_repo, whoami from packaging import version - from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer @@ -401,13 +400,10 @@ def main(): # Create EMA for the unet. if args.use_ema: - ema_unet = EMAModel( - unet.parameters(), - decay=args.ema_max_decay, - use_ema_warmup=True, - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) + ema_unet = EMAModel(ema_unet.parameters()) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -597,6 +593,10 @@ def collate_fn(examples): unet, optimizer, train_dataloader, lr_scheduler ) + if args.use_ema: + accelerator.register_for_checkpointing(ema_unet) + ema_unet.to(accelerator.device) + # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. 
weight_dtype = torch.float32 @@ -608,8 +608,6 @@ def collate_fn(examples): # Move text_encode and vae to gpu and cast to weight_dtype text_encoder.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - if args.use_ema: - ema_unet.to(accelerator.device) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -657,17 +655,6 @@ def collate_fn(examples): accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - if args.use_ema: - ema_model = UNet2DConditionModel.from_pretrained(os.path.join(args.output_dir, path, "unet_ema")) - ema_model = EMAModel( - ema_model, - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay, - device=accelerator.unwrap_model(unet).device, - optimization_step=global_step, - ) - resume_global_step = global_step * args.gradient_accumulation_steps first_epoch = global_step // num_update_steps_per_epoch resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) @@ -744,9 +731,6 @@ def collate_fn(examples): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") - if args.use_ema: - ema_unet.averaged_model.save_pretrained(os.path.join(save_path, "unet_ema")) - logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 86aaa162b656..38352fe86d3e 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -10,12 +10,9 @@ import torch import torch.nn.functional as F - import accelerate - import datasets import diffusers - from accelerate import Accelerator from accelerate.logging import get_logger from datasets import load_dataset @@ -23,7 +20,6 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version - from huggingface_hub import HfFolder, Repository, create_repo, whoami from packaging import version from torchvision.transforms import ( @@ -435,14 +431,11 @@ def transforms(examples): model, optimizer, train_dataloader, lr_scheduler ) - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.use_ema: + accelerator.register_for_checkpointing(ema_model) + ema_model.to(accelerator.device) - ema_model = EMAModel( - accelerator.unwrap_model(model), - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay, - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) # Handle the repository creation if accelerator.is_main_process: @@ -453,10 +446,6 @@ def transforms(examples): repo_name = args.hub_model_id repo = Repository(args.output_dir, clone_from=repo_name) - if args.use_ema: - accelerator.register_for_checkpointing(ema_model) - ema_model.to(accelerator.device) - # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. 
if accelerator.is_main_process: @@ -499,17 +488,6 @@ def transforms(examples): accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - if args.use_ema: - ema_model = UNet2DModel.from_pretrained(os.path.join(args.output_dir, path, "unet_ema")) - ema_model = EMAModel( - ema_model, - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay, - device=accelerator.unwrap_model(model).device, - optimization_step=global_step, - ) - resume_global_step = global_step * args.gradient_accumulation_steps first_epoch = global_step // num_update_steps_per_epoch resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) @@ -577,9 +555,6 @@ def transforms(examples): save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) - if args.use_ema: - ema_model.averaged_model.save_pretrained(os.path.join(save_path, "unet_ema")) - logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} @@ -598,7 +573,7 @@ def transforms(examples): if args.use_ema: ema_model.copy_to(unet.parameters()) pipeline = DDPMPipeline( - unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else unet), + unet=unet, scheduler=noise_scheduler, ) From 39e57c458b6a269a5cf52816130a0f138661bb92 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 30 Jan 2023 12:57:46 +0000 Subject: [PATCH 05/14] uP --- _ | 278 -------------------------------------------------------------- 1 file changed, 278 deletions(-) delete mode 100644 _ diff --git a/_ b/_ deleted file mode 100644 index 2f8856b616ac..000000000000 --- a/_ +++ /dev/null @@ -1,278 +0,0 @@ -import copy -import os -import random -from typing import Iterable, Union - -import numpy as np -import torch - -from .utils import deprecate - - -def enable_full_determinism(seed: int): - """ - Helper function for reproducible behavior during distributed training. See - - https://pytorch.org/docs/stable/notes/randomness.html for pytorch - """ - # set seed first - set_seed(seed) - - # Enable PyTorch deterministic mode. This potentially requires either the environment - # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, - # depending on the CUDA version, so we set them both here - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" - torch.use_deterministic_algorithms(True) - - # Enable CUDNN deterministic mode - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -def set_seed(seed: int): - """ - Args: - Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. - seed (`int`): The seed to set. - """ - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - # ^^ safe to call this function even if cuda is not available - - -# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 -class EMAModel: - """ - Exponential Moving Average of models weights - """ - - def __init__( - self, - parameters: Iterable[torch.nn.Parameter], - decay: float = 0.9999, - min_decay: float = 0.0, - update_after_step: int = 0, - use_ema_warmup: bool = False, - inv_gamma: Union[float, int] = 1.0, - power: Union[float, int] = 2 / 3, - optimization_step=0, - **kwargs, - ): - """ - Args: - parameters (Iterable[torch.nn.Parameter]): The parameters to track. 
- decay (float): The decay factor for the exponential moving average. - min_decay (float): The minimum decay factor for the exponential moving average. - update_after_step (int): The number of steps to wait before starting to update the EMA weights. - use_ema_warmup (bool): Whether to use EMA warmup. - inv_gamma (float): - Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. - power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. - device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA - weights will be stored on CPU. - - @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan - to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), - gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 - at 215.4k steps). - """ - - if isinstance(parameters, torch.nn.Module): - deprecation_message = ( - "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " - "Please pass the parameters of the module instead." - ) - deprecate( - "passing a `torch.nn.Module` to `ExponentialMovingAverage`", - "1.0.0", - deprecation_message, - standard_warn=False, - ) - parameters = parameters.parameters() - - # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility - use_ema_warmup = True - - if kwargs.get("max_value", None) is not None: - deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." - deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) - decay = kwargs["max_value"] - - if kwargs.get("min_value", None) is not None: - deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." - deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) - min_decay = kwargs["min_value"] - - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters] - - if kwargs.get("device", None) is not None: - deprecation_message = "The `device` argument is deprecated. Please use `to` instead." - deprecate("device", "1.0.0", deprecation_message, standard_warn=False) - self.to(device=kwargs["device"]) - - self.collected_params = None - - self.decay = decay - self.min_decay = min_decay - self.update_after_step = update_after_step - self.use_ema_warmup = use_ema_warmup - self.inv_gamma = inv_gamma - self.power = power - - if optimization_step == 0: - self.optimization_step = 0 - else: - self.optimization_step = optimization_step - - def get_decay(self, optimization_step: int) -> float: - """ - Compute the decay factor for the exponential moving average. - """ - step = max(0, optimization_step - self.update_after_step - 1) - - if step <= 0: - return 0.0 - - if self.use_ema_warmup: - cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power - else: - cur_decay_value = (1 + step) / (10 + step) - - cur_decay_value = min(cur_decay_value, self.decay) - # make sure decay is not smaller than min_decay - cur_decay_value = max(cur_decay_value, self.min_decay) - return cur_decay_value - - @torch.no_grad() - def step(self, parameters: Iterable[torch.nn.Parameter]): - if isinstance(parameters, torch.nn.Module): - deprecation_message = ( - "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. 
" - "Please pass the parameters of the module instead." - ) - deprecate( - "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", - "1.0.0", - deprecation_message, - standard_warn=False, - ) - parameters = parameters.parameters() - - parameters = list(parameters) - - self.optimization_step += 1 - - # Compute the decay factor for the exponential moving average. - decay = self.get_decay(self.optimization_step) - one_minus_decay = 1 - decay - - for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) - - torch.cuda.empty_cache() - - def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: - """ - Copy current averaged parameters into given collection of parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the parameters with which this - `ExponentialMovingAverage` was initialized will be used. - """ - parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.data) - - def to(self, device=None, dtype=None) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. - - Args: - device: like `device` argument to `torch.Tensor.to` - """ - # .to() on the tensors handles None correctly - self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) - for p in self.shadow_params - ] - - def state_dict(self) -> dict: - r""" - Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during - checkpointing to save the ema state dict. - """ - # Following PyTorch conventions, references to tensors are returned: - # "returns a reference to the state and not its copy!" - - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict - return { - "decay": self.decay, - "min_decay": self.decay, - "optimization_step": self.optimization_step, - "update_after_step": self.update_after_step, - "use_ema_warmup": self.use_ema_warmup, - "inv_gamma": self.inv_gamma, - "power": self.power, - "shadow_params": self.shadow_params, - "collected_params": self.collected_params, - } - - def load_state_dict(self, state_dict: dict) -> None: - r""" - Args: - Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the - ema state dict. - state_dict (dict): EMA state. Should be an object returned - from a call to :meth:`state_dict`. 
- """ - # deepcopy, to be consistent with module API - state_dict = copy.deepcopy(state_dict) - - self.decay = state_dict.get("decay", self.decay) - if self.decay < 0.0 or self.decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.min_decay = state_dict.get("min_decay", self.min_decay) - if not isinstance(self.min_decay, float): - raise ValueError("Invalid min_decay") - - self.optimization_step = state_dict.get("optimization_step", self.optimization_step) - if not isinstance(self.optimization_step, int): - raise ValueError("Invalid optimization_step") - - self.update_after_step = state_dict.get("update_after_step", self.update_after_step) - if not isinstance(self.update_after_step, int): - raise ValueError("Invalid update_after_step") - - self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) - if not isinstance(self.use_ema_warmup, bool): - raise ValueError("Invalid use_ema_warmup") - - self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) - if not isinstance(self.inv_gamma, (float, int)): - raise ValueError("Invalid inv_gamma") - - self.power = state_dict.get("power", self.power) - if not isinstance(self.power, (float, int)): - raise ValueError("Invalid power") - - self.shadow_params = state_dict["shadow_params"] - if not isinstance(self.shadow_params, list): - raise ValueError("shadow_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): - raise ValueError("shadow_params must all be Tensors") - - self.collected_params = state_dict["collected_params"] - if self.collected_params is not None: - if not isinstance(self.collected_params, list): - raise ValueError("collected_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.collected_params): - raise ValueError("collected_params must all be Tensors") - if len(self.collected_params) != len(self.shadow_params): - raise ValueError("collected_params and shadow_params must have the same length") From 6a20148309f46b6b96ffc14d472331bcb94d8e5b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 30 Jan 2023 13:01:08 +0000 Subject: [PATCH 06/14] up --- examples/text_to_image/train_text_to_image.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 75033f444eb6..f8c821674691 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -205,9 +205,6 @@ def parse_args(): ), ) parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") - parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.") - parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.") - parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.") parser.add_argument( "--non_ema_revision", type=str, From 190985340d44a17004c2b38ca1a2cd632710622a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 30 Jan 2023 13:01:45 +0000 Subject: [PATCH 07/14] up --- src/diffusers/training_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 2f8856b616ac..605267f35365 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -57,7 +57,6 @@ def __init__( use_ema_warmup: bool = False, inv_gamma: Union[float, int] = 
1.0, power: Union[float, int] = 2 / 3, - optimization_step=0, **kwargs, ): """ @@ -123,10 +122,7 @@ def __init__( self.inv_gamma = inv_gamma self.power = power - if optimization_step == 0: - self.optimization_step = 0 - else: - self.optimization_step = optimization_step + self.optimization_step = 0 def get_decay(self, optimization_step: int) -> float: """ From 4653b4369a3f6ed97e3122e8ff5a6137432a2069 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 30 Jan 2023 13:03:30 +0000 Subject: [PATCH 08/14] fix --- .../train_unconditional.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 38352fe86d3e..276f055b6848 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -435,17 +435,6 @@ def transforms(examples): accelerator.register_for_checkpointing(ema_model) ema_model.to(accelerator.device) - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - - # Handle the repository creation - if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - repo = Repository(args.output_dir, clone_from=repo_name) - # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: From 415fb91e22eba211447ae889df065893c538699b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 30 Jan 2023 14:04:06 +0100 Subject: [PATCH 09/14] Apply suggestions from code review --- examples/unconditional_image_generation/train_unconditional.py | 1 - src/diffusers/training_utils.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 276f055b6848..39944d63ff79 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -543,7 +543,6 @@ def transforms(examples): if accelerator.is_main_process: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 605267f35365..c5449556a12f 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -121,7 +121,6 @@ def __init__( self.use_ema_warmup = use_ema_warmup self.inv_gamma = inv_gamma self.power = power - self.optimization_step = 0 def get_decay(self, optimization_step: int) -> float: From 07b78783f05653fcf5dec94e2b66ac547b74c10e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 7 Feb 2023 16:22:14 +0000 Subject: [PATCH 10/14] correct ema --- @ | 297 ++++++++++++++++++ examples/dreambooth/train_dreambooth.py | 4 +- examples/text_to_image/train_text_to_image.py | 4 +- .../train_unconditional.py | 18 +- src/diffusers/training_utils.py | 54 +++- 5 files changed, 357 insertions(+), 20 deletions(-) create mode 100644 @ diff --git a/@ b/@ new file mode 100644 index 000000000000..6cf36e1ba2ca --- /dev/null +++ b/@ @@ -0,0 
+1,297 @@ +import copy +import os +import random +from typing import Iterable, Union + +import numpy as np +import torch + +from .utils import deprecate + + +def enable_full_determinism(seed: int): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + """ + # set seed first + set_seed(seed) + + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def set_seed(seed: int): + """ + Args: + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + + +# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 +class EMAModel(torch.nn.Module): + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + use_ema_warmup: bool = False, + inv_gamma: Union[float, int] = 1.0, + power: Union[float, int] = 2 / 3, + model_cls: Optional[Any] = None, + model_config: Dict[str, Any] = None, + **kwargs, + ): + super().__init__() + """ + Args: + parameters (Iterable[torch.nn.Parameter]): The parameters to track. + decay (float): The decay factor for the exponential moving average. + min_decay (float): The minimum decay factor for the exponential moving average. + update_after_step (int): The number of steps to wait before starting to update the EMA weights. + use_ema_warmup (bool): Whether to use EMA warmup. + inv_gamma (float): + Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. + power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA + weights will be stored on CPU. + + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + """ + + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + + # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility + use_ema_warmup = True + + if kwargs.get("max_value", None) is not None: + deprecation_message = "The `max_value` argument is deprecated. 
Please use `decay` instead." + deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) + decay = kwargs["max_value"] + + if kwargs.get("min_value", None) is not None: + deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." + deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) + min_decay = kwargs["min_value"] + + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + + if kwargs.get("device", None) is not None: + deprecation_message = "The `device` argument is deprecated. Please use `to` instead." + deprecate("device", "1.0.0", deprecation_message, standard_warn=False) + self.to(device=kwargs["device"]) + + self.collected_params = None + + self.decay = decay + self.min_decay = min_decay + self.update_after_step = update_after_step + self.use_ema_warmup = use_ema_warmup + self.inv_gamma = inv_gamma + self.power = power + self.optimization_step = 0 + + self.model_cls = model_cls + self.model_config = model_config + + def from_pretrained(cls, path) -> "EMAModel": + _, ema_kwargs = cls.model_cls.load_config(path, return_unused_kwargs=True) + model = cls.model_cls.from_pretrained(path) + + parameters = list(model.parameters()) + ema_model = cls(parameters, model_config=model.config) + + ema_model.load_state_dict(**ema_kwargs) + return ema_model + + def save_pretrained(self, path): + state_dict = self.shadow_params + model = self.model_cls(self.config) + self.copy_to(model.parameters()) + + model.save_pretrained(path) + + def get_decay(self, optimization_step: int) -> float: + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + + if step <= 0: + return 0.0 + + if self.use_ema_warmup: + cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power + else: + cur_decay_value = (1 + step) / (10 + step) + + cur_decay_value = min(cur_decay_value, self.decay) + # make sure decay is not smaller than min_decay + cur_decay_value = max(cur_decay_value, self.min_decay) + return cur_decay_value + + @torch.no_grad() + def step(self, parameters: Iterable[torch.nn.Parameter]): + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + + parameters = list(parameters) + + self.optimization_step += 1 + + # Compute the decay factor for the exponential moving average. + decay = self.get_decay(self.optimization_step) + one_minus_decay = 1 - decay + + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) + + torch.cuda.empty_cache() + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. 
+ """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during + checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "min_decay": self.decay, + "optimization_step": self.optimization_step, + "update_after_step": self.update_after_step, + "use_ema_warmup": self.use_ema_warmup, + "inv_gamma": self.inv_gamma, + "power": self.power, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params, + } + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Args: + Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the + ema state dict. + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict.get("decay", self.decay) + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.min_decay = state_dict.get("min_decay", self.min_decay) + if not isinstance(self.min_decay, float): + raise ValueError("Invalid min_decay") + + self.optimization_step = state_dict.get("optimization_step", self.optimization_step) + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.update_after_step = state_dict.get("update_after_step", self.update_after_step) + if not isinstance(self.update_after_step, int): + raise ValueError("Invalid update_after_step") + + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) + if not isinstance(self.use_ema_warmup, bool): + raise ValueError("Invalid use_ema_warmup") + + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) + if not isinstance(self.inv_gamma, (float, int)): + raise ValueError("Invalid inv_gamma") + + self.power = state_dict.get("power", self.power) + if not isinstance(self.power, (float, int)): + raise ValueError("Invalid power") + + self.shadow_params = state_dict.get("shadow_params", None) + if self.shadow_params is not None: + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") + + self.collected_params = state_dict.get("collected_params", None) + if self.collected_params is not None: + if not isinstance(self.collected_params, list): + raise ValueError("collected_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.collected_params): + raise ValueError("collected_params must all be Tensors") + if len(self.collected_params) != len(self.shadow_params): + raise 
ValueError("collected_params and shadow_params must have the same length") diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 6e1828523a7c..989c2604a96d 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -611,8 +611,8 @@ def main(args): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) - # `accelerate` 0.15.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"): + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): for model in models: diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index f8c821674691..1afceabad5ee 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -408,8 +408,8 @@ def main(): else: raise ValueError("xformers is not available. Make sure it is installed correctly") - # `accelerate` 0.15.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"): + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): for i, model in enumerate(models): diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 39944d63ff79..0bcc80081a2d 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -267,10 +267,13 @@ def main(args): logging_dir=logging_dir, ) - # `accelerate` 0.15.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"): + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_model.save_pretrained(os.path.join(output_dir, "unet_ema")) + for i, model in enumerate(models): model.save_pretrained(os.path.join(output_dir, "unet")) @@ -278,6 +281,11 @@ def save_model_hook(models, weights, output_dir): weights.pop() def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel) + ema_model.load_state_dict(load_model.state_dict()) + del load_model + for i in range(len(models)): # pop models so that they are not loaded again model = models.pop() @@ -357,6 +365,8 @@ def load_model_hook(models, input_dir): use_ema_warmup=True, inv_gamma=args.ema_inv_gamma, power=args.ema_power, + model_cls=UNet2DModel, + model_config=model.config, ) # Initialize the scheduler @@ -431,10 +441,6 @@ def transforms(examples): model, optimizer, train_dataloader, lr_scheduler ) - if args.use_ema: - accelerator.register_for_checkpointing(ema_model) - 
ema_model.to(accelerator.device) - # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index c5449556a12f..960109edc0d6 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -1,7 +1,7 @@ import copy import os import random -from typing import Iterable, Union +from typing import Any, Dict, Iterable, Optional, Union import numpy as np import torch @@ -57,8 +57,11 @@ def __init__( use_ema_warmup: bool = False, inv_gamma: Union[float, int] = 1.0, power: Union[float, int] = 2 / 3, + model_cls: Optional[Any] = None, + model_config: Dict[str, Any] = None, **kwargs, ): + super().__init__() """ Args: parameters (Iterable[torch.nn.Parameter]): The parameters to track. @@ -123,6 +126,35 @@ def __init__( self.power = power self.optimization_step = 0 + self.model_cls = model_cls + self.model_config = model_config + + @classmethod + def from_pretrained(cls, path, model_cls) -> "EMAModel": + _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) + model = model_cls.from_pretrained(path) + + ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) + + ema_model.load_state_dict(ema_kwargs) + return ema_model + + def save_pretrained(self, path): + if self.model_cls is None: + raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") + + if self.model_config is None: + raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") + + model = self.model_cls.from_config(self.model_config) + state_dict = self.state_dict() + state_dict.pop("shadow_params", None) + state_dict.pop("collected_params", None) + + model.register_to_config(**state_dict) + self.copy_to(model.parameters()) + model.save_pretrained(path) + def get_decay(self, optimization_step: int) -> float: """ Compute the decay factor for the exponential moving average. @@ -167,9 +199,9 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) + s_param.sub_(one_minus_decay * (s_param - param.cpu())) else: - s_param.copy_(param) + s_param.copy_(param.cpu()) torch.cuda.empty_cache() @@ -184,7 +216,7 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: """ parameters = list(parameters) for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.data) + param.data.copy_(s_param.to(param.device).data) def to(self, device=None, dtype=None) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
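The `model_cls` / `model_config` arguments introduced above are what make the EMA weights self-contained: `save_pretrained` writes them out as a regular diffusers model folder, and `from_pretrained` rebuilds an `EMAModel` from that folder, which is what the new `unet_ema` save/load hooks rely on. Below is a minimal sketch of that round trip, assuming a diffusers build that already contains these changes; the model id and checkpoint path are placeholders.

```python
import os

from diffusers import UNet2DModel
from diffusers.training_utils import EMAModel

unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")  # placeholder model id

# Track an EMA copy of the UNet; model_cls/model_config make it serializable on its own.
ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DModel, model_config=unet.config)

# ...after every optimizer step during training...
ema_unet.step(unet.parameters())

# What the save hook does: store the EMA weights next to the regular checkpoint.
output_dir = "checkpoint-500"  # placeholder path
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

# What the load hook does when resuming: rebuild the EMA copy and restore its state.
loaded_ema = EMAModel.from_pretrained(os.path.join(output_dir, "unet_ema"), UNet2DModel)
ema_unet.load_state_dict(loaded_ema.state_dict())
del loaded_ema
```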
@@ -257,13 +289,15 @@ def load_state_dict(self, state_dict: dict) -> None: if not isinstance(self.power, (float, int)): raise ValueError("Invalid power") - self.shadow_params = state_dict["shadow_params"] - if not isinstance(self.shadow_params, list): - raise ValueError("shadow_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): - raise ValueError("shadow_params must all be Tensors") + shadow_params = state_dict.get("shadow_params", None) + if shadow_params is not None: + self.shadow_params = shadow_params + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") - self.collected_params = state_dict["collected_params"] + self.collected_params = state_dict.get("collected_params", None) if self.collected_params is not None: if not isinstance(self.collected_params, list): raise ValueError("collected_params must be a list") From 3c0fe6b3dab284f42c72b08ebf4423bccdccce81 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 7 Feb 2023 16:22:24 +0000 Subject: [PATCH 11/14] Remove @ --- @ | 297 -------------------------------------------------------------- 1 file changed, 297 deletions(-) delete mode 100644 @ diff --git a/@ b/@ deleted file mode 100644 index 6cf36e1ba2ca..000000000000 --- a/@ +++ /dev/null @@ -1,297 +0,0 @@ -import copy -import os -import random -from typing import Iterable, Union - -import numpy as np -import torch - -from .utils import deprecate - - -def enable_full_determinism(seed: int): - """ - Helper function for reproducible behavior during distributed training. See - - https://pytorch.org/docs/stable/notes/randomness.html for pytorch - """ - # set seed first - set_seed(seed) - - # Enable PyTorch deterministic mode. This potentially requires either the environment - # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, - # depending on the CUDA version, so we set them both here - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" - torch.use_deterministic_algorithms(True) - - # Enable CUDNN deterministic mode - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -def set_seed(seed: int): - """ - Args: - Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. - seed (`int`): The seed to set. - """ - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - # ^^ safe to call this function even if cuda is not available - - -# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 -class EMAModel(torch.nn.Module): - """ - Exponential Moving Average of models weights - """ - - def __init__( - self, - parameters: Iterable[torch.nn.Parameter], - decay: float = 0.9999, - min_decay: float = 0.0, - update_after_step: int = 0, - use_ema_warmup: bool = False, - inv_gamma: Union[float, int] = 1.0, - power: Union[float, int] = 2 / 3, - model_cls: Optional[Any] = None, - model_config: Dict[str, Any] = None, - **kwargs, - ): - super().__init__() - """ - Args: - parameters (Iterable[torch.nn.Parameter]): The parameters to track. - decay (float): The decay factor for the exponential moving average. - min_decay (float): The minimum decay factor for the exponential moving average. 
- update_after_step (int): The number of steps to wait before starting to update the EMA weights. - use_ema_warmup (bool): Whether to use EMA warmup. - inv_gamma (float): - Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. - power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. - device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA - weights will be stored on CPU. - - @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan - to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), - gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 - at 215.4k steps). - """ - - if isinstance(parameters, torch.nn.Module): - deprecation_message = ( - "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " - "Please pass the parameters of the module instead." - ) - deprecate( - "passing a `torch.nn.Module` to `ExponentialMovingAverage`", - "1.0.0", - deprecation_message, - standard_warn=False, - ) - parameters = parameters.parameters() - - # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility - use_ema_warmup = True - - if kwargs.get("max_value", None) is not None: - deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." - deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) - decay = kwargs["max_value"] - - if kwargs.get("min_value", None) is not None: - deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." - deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) - min_decay = kwargs["min_value"] - - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters] - - if kwargs.get("device", None) is not None: - deprecation_message = "The `device` argument is deprecated. Please use `to` instead." - deprecate("device", "1.0.0", deprecation_message, standard_warn=False) - self.to(device=kwargs["device"]) - - self.collected_params = None - - self.decay = decay - self.min_decay = min_decay - self.update_after_step = update_after_step - self.use_ema_warmup = use_ema_warmup - self.inv_gamma = inv_gamma - self.power = power - self.optimization_step = 0 - - self.model_cls = model_cls - self.model_config = model_config - - def from_pretrained(cls, path) -> "EMAModel": - _, ema_kwargs = cls.model_cls.load_config(path, return_unused_kwargs=True) - model = cls.model_cls.from_pretrained(path) - - parameters = list(model.parameters()) - ema_model = cls(parameters, model_config=model.config) - - ema_model.load_state_dict(**ema_kwargs) - return ema_model - - def save_pretrained(self, path): - state_dict = self.shadow_params - model = self.model_cls(self.config) - self.copy_to(model.parameters()) - - model.save_pretrained(path) - - def get_decay(self, optimization_step: int) -> float: - """ - Compute the decay factor for the exponential moving average. 
- """ - step = max(0, optimization_step - self.update_after_step - 1) - - if step <= 0: - return 0.0 - - if self.use_ema_warmup: - cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power - else: - cur_decay_value = (1 + step) / (10 + step) - - cur_decay_value = min(cur_decay_value, self.decay) - # make sure decay is not smaller than min_decay - cur_decay_value = max(cur_decay_value, self.min_decay) - return cur_decay_value - - @torch.no_grad() - def step(self, parameters: Iterable[torch.nn.Parameter]): - if isinstance(parameters, torch.nn.Module): - deprecation_message = ( - "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " - "Please pass the parameters of the module instead." - ) - deprecate( - "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", - "1.0.0", - deprecation_message, - standard_warn=False, - ) - parameters = parameters.parameters() - - parameters = list(parameters) - - self.optimization_step += 1 - - # Compute the decay factor for the exponential moving average. - decay = self.get_decay(self.optimization_step) - one_minus_decay = 1 - decay - - for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) - - torch.cuda.empty_cache() - - def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: - """ - Copy current averaged parameters into given collection of parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the parameters with which this - `ExponentialMovingAverage` was initialized will be used. - """ - parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.data) - - def to(self, device=None, dtype=None) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. - - Args: - device: like `device` argument to `torch.Tensor.to` - """ - # .to() on the tensors handles None correctly - self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) - for p in self.shadow_params - ] - - def state_dict(self) -> dict: - r""" - Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during - checkpointing to save the ema state dict. - """ - # Following PyTorch conventions, references to tensors are returned: - # "returns a reference to the state and not its copy!" - - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict - return { - "decay": self.decay, - "min_decay": self.decay, - "optimization_step": self.optimization_step, - "update_after_step": self.update_after_step, - "use_ema_warmup": self.use_ema_warmup, - "inv_gamma": self.inv_gamma, - "power": self.power, - "shadow_params": self.shadow_params, - "collected_params": self.collected_params, - } - - def load_state_dict(self, state_dict: dict) -> None: - r""" - Args: - Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the - ema state dict. - state_dict (dict): EMA state. Should be an object returned - from a call to :meth:`state_dict`. 
- """ - # deepcopy, to be consistent with module API - state_dict = copy.deepcopy(state_dict) - - self.decay = state_dict.get("decay", self.decay) - if self.decay < 0.0 or self.decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.min_decay = state_dict.get("min_decay", self.min_decay) - if not isinstance(self.min_decay, float): - raise ValueError("Invalid min_decay") - - self.optimization_step = state_dict.get("optimization_step", self.optimization_step) - if not isinstance(self.optimization_step, int): - raise ValueError("Invalid optimization_step") - - self.update_after_step = state_dict.get("update_after_step", self.update_after_step) - if not isinstance(self.update_after_step, int): - raise ValueError("Invalid update_after_step") - - self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) - if not isinstance(self.use_ema_warmup, bool): - raise ValueError("Invalid use_ema_warmup") - - self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) - if not isinstance(self.inv_gamma, (float, int)): - raise ValueError("Invalid inv_gamma") - - self.power = state_dict.get("power", self.power) - if not isinstance(self.power, (float, int)): - raise ValueError("Invalid power") - - self.shadow_params = state_dict.get("shadow_params", None) - if self.shadow_params is not None: - if not isinstance(self.shadow_params, list): - raise ValueError("shadow_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): - raise ValueError("shadow_params must all be Tensors") - - self.collected_params = state_dict.get("collected_params", None) - if self.collected_params is not None: - if not isinstance(self.collected_params, list): - raise ValueError("collected_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.collected_params): - raise ValueError("collected_params must all be Tensors") - if len(self.collected_params) != len(self.shadow_params): - raise ValueError("collected_params and shadow_params must have the same length") From 7e2dc2f2239750c3003be4744d6b2d3274af54e7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 7 Feb 2023 16:56:05 +0000 Subject: [PATCH 12/14] up --- examples/text_to_image/train_text_to_image.py | 9 ++++++++- .../train_unconditional.py | 4 ++++ src/diffusers/training_utils.py | 4 ++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 1afceabad5ee..e15781aacec7 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -412,6 +412,9 @@ def main(): if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + for i, model in enumerate(models): model.save_pretrained(os.path.join(output_dir, "unet")) @@ -419,6 +422,11 @@ def save_model_hook(models, weights, output_dir): weights.pop() def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + del load_model + for i in range(len(models)): # pop models so that they are not loaded again model = models.pop() @@ -591,7 +599,6 @@ def collate_fn(examples): ) if args.use_ema: 
- accelerator.register_for_checkpointing(ema_unet) ema_unet.to(accelerator.device) # For mixed precision training we cast the text_encoder and vae weights to half-precision diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index f860f77b9e31..bdba3e7805be 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -290,6 +290,7 @@ def load_model_hook(models, input_dir): if args.use_ema: load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel) ema_model.load_state_dict(load_model.state_dict()) + ema_model.to(accelerator.device) del load_model for i in range(len(models)): @@ -447,6 +448,9 @@ def transform_images(examples): model, optimizer, train_dataloader, lr_scheduler ) + if args.use_ema: + ema_model.to(accelerator.device) + # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 960109edc0d6..f645a9d4365b 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -199,9 +199,9 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param.cpu())) + s_param.sub_(one_minus_decay * (s_param - param)) else: - s_param.copy_(param.cpu()) + s_param.copy_(param) torch.cuda.empty_cache() From 60aa8e1766263aa007977d4acc82a339afadd2d8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 7 Feb 2023 17:59:10 +0100 Subject: [PATCH 13/14] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- docs/source/en/training/dreambooth.mdx | 6 +++--- examples/dreambooth/train_dreambooth.py | 2 +- src/diffusers/training_utils.py | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/source/en/training/dreambooth.mdx b/docs/source/en/training/dreambooth.mdx index 6e1ecb2b1059..abac7602addb 100644 --- a/docs/source/en/training/dreambooth.mdx +++ b/docs/source/en/training/dreambooth.mdx @@ -127,7 +127,7 @@ This would be a good opportunity to tweak some of your hyperparameters if you wi Saved checkpoints are stored in a format suitable for resuming training. They not only include the model weights, but also the state of the optimizer, data loaders and learning rate. -**Note**: If you have installed `"accelerate>=0.15.0dev0"` you can use the following code to run +**Note**: If you have installed `"accelerate>=0.16.0"` you can use the following code to run inference from an intermediate checkpoint. ```python @@ -294,8 +294,8 @@ accelerate launch train_dreambooth.py \ Once you have trained a model, inference can be done using the `StableDiffusionPipeline`, by simply indicating the path where the model was saved. Make sure that your prompts include the special `identifier` used during training (`sks` in the previous examples). -**Note**: If you have installed `"accelerate>=0.15.0dev0"` you can use the following code to run -inference from an intermediat checkpoint. +**Note**: If you have installed `"accelerate>=0.16.0"` you can use the following code to run +inference from an intermediate checkpoint. 
```python diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index fd1282cbc6df..6d2a8a97cc03 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -625,7 +625,7 @@ def load_model_hook(models, input_dir): model = models.pop() if type(model) == type(text_encoder): - # load transformer style into model + # load transformers style into model load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") model.config = load_model.config else: diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index f645a9d4365b..5df1c1bd6373 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -61,7 +61,6 @@ def __init__( model_config: Dict[str, Any] = None, **kwargs, ): - super().__init__() """ Args: parameters (Iterable[torch.nn.Parameter]): The parameters to track. From a8b6a0ae5b3025643be8b5f7b64d49a805912d92 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 7 Feb 2023 19:40:15 +0100 Subject: [PATCH 14/14] Update docs/source/en/training/dreambooth.mdx Co-authored-by: Pedro Cuenca --- docs/source/en/training/dreambooth.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/training/dreambooth.mdx b/docs/source/en/training/dreambooth.mdx index abac7602addb..5ff5cca4bf82 100644 --- a/docs/source/en/training/dreambooth.mdx +++ b/docs/source/en/training/dreambooth.mdx @@ -150,7 +150,7 @@ pipeline.to("cuda") pipeline.save_pretrained("dreambooth-pipeline") ``` -If you have installed `"accelerate<0.15.0dev0"` you need to first convert it to an inference pipeline. This is how you could do it: +If you have installed `"accelerate<0.16.0"` you need to first convert it to an inference pipeline. This is how you could do it: ```python from accelerate import Accelerator
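# The rest of this snippet is only an illustrative sketch of the conversion described
# above; the model id and checkpoint path are placeholders and should match the values
# that were used for training.
from diffusers import DiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"  # placeholder: base model used for training
pipeline = DiffusionPipeline.from_pretrained(model_id)

accelerator = Accelerator()

# Prepare the trained sub-models the same way the training script did
# (include the text encoder only if it was trained as well).
unet, text_encoder = accelerator.prepare(pipeline.unet, pipeline.text_encoder)

# Restore the intermediate training state written by `accelerator.save_state(...)`.
accelerator.load_state("/path/to/checkpoint-100")  # placeholder checkpoint path

# Rebuild an inference pipeline from the unwrapped models and save it.
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    unet=accelerator.unwrap_model(unet),
    text_encoder=accelerator.unwrap_model(text_encoder),
)
pipeline.save_pretrained("dreambooth-pipeline")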