diff --git a/docs/source/en/training/dreambooth.mdx b/docs/source/en/training/dreambooth.mdx
index 7f9de798c470..5ff5cca4bf82 100644
--- a/docs/source/en/training/dreambooth.mdx
+++ b/docs/source/en/training/dreambooth.mdx
@@ -127,7 +127,30 @@ This would be a good opportunity to tweak some of your hyperparameters if you wish.
 
 Saved checkpoints are stored in a format suitable for resuming training. They not only include the model weights, but also the state of the optimizer, data loaders and learning rate.
 
-You can use a checkpoint for inference, but first you need to convert it to an inference pipeline. This is how you could do it:
+**Note**: If you have installed `"accelerate>=0.16.0"`, you can use the following code to run
+inference from an intermediate checkpoint.
+
+```python
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+from transformers import CLIPTextModel
+import torch
+
+# Load the pipeline with the same arguments (model, revision) that were used for training
+model_id = "CompVis/stable-diffusion-v1-4"
+
+unet = UNet2DConditionModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/unet")
+
+# if you have trained with `--train_text_encoder`, make sure to also load the text encoder
+text_encoder = CLIPTextModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/text_encoder")
+
+pipeline = DiffusionPipeline.from_pretrained(model_id, unet=unet, text_encoder=text_encoder, torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+# Perform inference, or save, or push to the hub
+pipeline.save_pretrained("dreambooth-pipeline")
+```
+
+If you have installed `"accelerate<0.16.0"`, you first need to convert the checkpoint to an inference pipeline. This is how you could do it:
 
 ```python
 from accelerate import Accelerator
@@ -284,4 +307,4 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
 
 image.save("dog-bucket.png")
 ```
-You may also run inference from [any of the saved training checkpoints](#performing-inference-using-a-saved-checkpoint).
\ No newline at end of file
+You may also run inference from [any of the saved training checkpoints](#performing-inference-using-a-saved-checkpoint).
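The docs above lean on custom `accelerate` save/load hooks that the training scripts below register, so that `accelerator.save_state(...)` writes checkpoints in the standard `diffusers` folder layout instead of an opaque pickle. As a minimal, self-contained sketch of that mechanism (the model class and checkpoint path are illustrative, and a single process without model wrapping is assumed):

```python
# Sketch of the accelerate>=0.16.0 hook mechanism the scripts below rely on.
import os

from accelerate import Accelerator
from diffusers import UNet2DModel

accelerator = Accelerator()
unet = UNet2DModel()  # stand-in for the model being trained


def save_model_hook(models, weights, output_dir):
    # called from inside `accelerator.save_state(...)`
    for model in models:
        model.save_pretrained(os.path.join(output_dir, "unet"))
        # pop the raw state dict so accelerate does not serialize it a second time
        weights.pop()


def load_model_hook(models, input_dir):
    # called from inside `accelerator.load_state(...)`
    while len(models) > 0:
        # pop models so that they are not loaded twice
        model = models.pop()
        loaded = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
        model.register_to_config(**loaded.config)
        model.load_state_dict(loaded.state_dict())
        del loaded


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

unet = accelerator.prepare(unet)
accelerator.save_state("checkpoint-100")  # writes checkpoint-100/unet/ in diffusers format
accelerator.load_state("checkpoint-100")  # round-trips through the hooks above
```

Because the hooks serialize in the usual `diffusers` layout, the `unet` subfolder of any intermediate checkpoint can be passed straight to `from_pretrained`, which is exactly what the documentation snippet above does.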
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 9cbac6f13fed..6d2a8a97cc03 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -28,6 +28,7 @@
 import torch.utils.checkpoint
 from torch.utils.data import Dataset
 
+import accelerate
 import diffusers
 import transformers
 from accelerate import Accelerator
@@ -38,6 +39,7 @@
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from packaging import version
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -606,6 +608,37 @@ def main(args):
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
 
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            for model in models:
+                sub_dir = "unet" if type(model) == type(unet) else "text_encoder"
+                model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+        def load_model_hook(models, input_dir):
+            while len(models) > 0:
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                if type(model) == type(text_encoder):
+                    # load transformers style into model
+                    load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
+                    model.config = load_model.config
+                else:
+                    # load diffusers style into model
+                    load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+                    model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
     vae.requires_grad_(False)
     if not args.train_text_encoder:
         text_encoder.requires_grad_(False)
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 7f6ddeaee135..e15781aacec7 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -26,6 +26,7 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint
 
+import accelerate
 import datasets
 import diffusers
 import transformers
@@ -36,9 +37,10 @@
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, deprecate
 from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -319,6 +321,16 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
 
 def main():
     args = parse_args()
+
+    if args.non_ema_revision is not None:
+        deprecate(
+            "non_ema_revision!=None",
+            "0.15.0",
+            message=(
+                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+                " use `--variant=non_ema` instead."
+            ),
+        )
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
 
     accelerator = Accelerator(
@@ -396,6 +408,39 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if args.use_ema:
+                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+            for i, model in enumerate(models):
+                model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+        def load_model_hook(models, input_dir):
+            if args.use_ema:
+                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+                ema_unet.load_state_dict(load_model.state_dict())
+                del load_model
+
+            for i in range(len(models)):
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -552,8 +597,9 @@ def collate_fn(examples):
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
+
     if args.use_ema:
-        accelerator.register_for_checkpointing(ema_unet)
+        ema_unet.to(accelerator.device)
 
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
@@ -566,8 +612,6 @@ def collate_fn(examples):
     # Move text_encode and vae to gpu and cast to weight_dtype
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
-    if args.use_ema:
-        ema_unet.to(accelerator.device)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index d4df7adacb88..bdba3e7805be 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -9,6 +9,7 @@
 import torch
 import torch.nn.functional as F
 
+import accelerate
 import datasets
 import diffusers
 from accelerate import Accelerator
@@ -19,6 +20,7 @@
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
 
@@ -271,6 +273,40 @@ def main(args):
         logging_dir=logging_dir,
     )
 
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if args.use_ema:
+                ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+            for i, model in enumerate(models):
+                model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+        def load_model_hook(models, input_dir):
+            if args.use_ema:
+                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
+                ema_model.load_state_dict(load_model.state_dict())
+                ema_model.to(accelerator.device)
+                del load_model
+
+            for i in range(len(models)):
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -336,6 +372,8 @@ def main(args):
         use_ema_warmup=True,
         inv_gamma=args.ema_inv_gamma,
         power=args.ema_power,
+        model_cls=UNet2DModel,
+        model_config=model.config,
     )
 
     # Initialize the scheduler
@@ -411,7 +449,6 @@ def transform_images(examples):
     )
 
     if args.use_ema:
-        accelerator.register_for_checkpointing(ema_model)
         ema_model.to(accelerator.device)
 
     # We need to initialize the trackers we use, and also store our configuration.
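When `--use_ema` is passed, the hooks above write two copies of the model into each checkpoint: the live training weights under `unet/` and the EMA shadow weights (plus the EMA hyperparameters) under `unet_ema/`. Restoring the averaged weights outside the training loop then takes only a few lines; a sketch with an illustrative checkpoint path, using the `EMAModel` serialization API added in `training_utils.py` below:

```python
from diffusers import UNet2DModel
from diffusers.training_utils import EMAModel

# live training weights, written by save_model_hook via model.save_pretrained
unet = UNet2DModel.from_pretrained("output/checkpoint-500/unet")

# EMA shadow weights + decay/step configuration, written via ema_model.save_pretrained
ema = EMAModel.from_pretrained("output/checkpoint-500/unet_ema", UNet2DModel)

# overwrite the live weights with the exponential moving average
ema.copy_to(unet.parameters())
```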
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index c5449556a12f..5df1c1bd6373 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -1,7 +1,7 @@
 import copy
 import os
 import random
-from typing import Iterable, Union
+from typing import Any, Dict, Iterable, Optional, Union
 
 import numpy as np
 import torch
@@ -57,6 +57,8 @@ def __init__(
         use_ema_warmup: bool = False,
         inv_gamma: Union[float, int] = 1.0,
         power: Union[float, int] = 2 / 3,
+        model_cls: Optional[Any] = None,
+        model_config: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
         """
@@ -123,6 +125,35 @@ def __init__(
         self.power = power
         self.optimization_step = 0
+        self.model_cls = model_cls
+        self.model_config = model_config
+
+    @classmethod
+    def from_pretrained(cls, path, model_cls) -> "EMAModel":
+        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
+        model = model_cls.from_pretrained(path)
+
+        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
+
+        ema_model.load_state_dict(ema_kwargs)
+        return ema_model
+
+    def save_pretrained(self, path):
+        if self.model_cls is None:
+            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
+
+        if self.model_config is None:
+            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")
+
+        model = self.model_cls.from_config(self.model_config)
+        state_dict = self.state_dict()
+        state_dict.pop("shadow_params", None)
+        state_dict.pop("collected_params", None)
+
+        model.register_to_config(**state_dict)
+
+        self.copy_to(model.parameters())
+        model.save_pretrained(path)
 
     def get_decay(self, optimization_step: int) -> float:
         """
         Compute the decay factor for the exponential moving average.
@@ -184,7 +215,7 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         parameters = list(parameters)
         for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.data)
+            param.data.copy_(s_param.to(param.device).data)
 
     def to(self, device=None, dtype=None) -> None:
         r"""Move internal buffers of the ExponentialMovingAverage to `device`.
@@ -257,13 +288,15 @@ def load_state_dict(self, state_dict: dict) -> None:
         if not isinstance(self.power, (float, int)):
             raise ValueError("Invalid power")
 
-        self.shadow_params = state_dict["shadow_params"]
-        if not isinstance(self.shadow_params, list):
-            raise ValueError("shadow_params must be a list")
-        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-            raise ValueError("shadow_params must all be Tensors")
+        shadow_params = state_dict.get("shadow_params", None)
+        if shadow_params is not None:
+            self.shadow_params = shadow_params
+            if not isinstance(self.shadow_params, list):
+                raise ValueError("shadow_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+                raise ValueError("shadow_params must all be Tensors")
 
-        self.collected_params = state_dict["collected_params"]
+        self.collected_params = state_dict.get("collected_params", None)
         if self.collected_params is not None:
             if not isinstance(self.collected_params, list):
                 raise ValueError("collected_params must be a list")
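Taken together, the new methods give `EMAModel` a save/load round trip that no longer depends on `accelerate` pickling the whole object. A minimal sketch of that round trip (the model and hyperparameters are illustrative; any `diffusers` model class carrying a config should behave the same way):

```python
from diffusers import UNet2DModel
from diffusers.training_utils import EMAModel

model = UNet2DModel()
ema = EMAModel(model.parameters(), decay=0.9999, model_cls=UNet2DModel, model_config=model.config)

# normally called once per optimizer step to update the shadow parameters
ema.step(model.parameters())

# save_pretrained materializes the shadow params as model weights and registers the
# remaining EMA state (decay, optimization_step, ...) as extra config entries ...
ema.save_pretrained("unet_ema")

# ... which from_pretrained reads back via load_config(..., return_unused_kwargs=True)
restored = EMAModel.from_pretrained("unet_ema", UNet2DModel)

# copy_to is now device-safe: each shadow param is moved to its target param's device
restored.copy_to(model.parameters())
```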