31 changes: 29 additions & 2 deletions docs/source/en/training/dreambooth.mdx
@@ -127,7 +127,30 @@ This would be a good opportunity to tweak some of your hyperparameters if you wish.

Saved checkpoints are stored in a format suitable for resuming training. They not only include the model weights, but also the state of the optimizer, data loaders and learning rate.
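To resume training from one of these checkpoints, pass `--resume_from_checkpoint` to the training script. A minimal sketch (flags and paths shown are illustrative; reuse the remaining flags from your original run):

```bash
# Resume from the checkpoint saved after 100 steps; all other flags should
# match the original training run.
accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
  --instance_data_dir="./instance-images" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="/sddata/dreambooth/daruma-v2-1" \
  --resume_from_checkpoint="checkpoint-100"
```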

**Note**: If you have installed `"accelerate>=0.16.0"`, you can use the following code to run inference from an intermediate checkpoint:

```python
from diffusers import DiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel
import torch

# Load the pipeline with the same arguments (model, revision) that were used for training
model_id = "CompVis/stable-diffusion-v1-4"

unet = UNet2DConditionModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/unet")

# if you trained with `--train_text_encoder`, make sure to also load the text encoder
text_encoder = CLIPTextModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/text_encoder")

pipeline = DiffusionPipeline.from_pretrained(model_id, unet=unet, text_encoder=text_encoder, torch_dtype=torch.float16)
pipeline.to("cuda")

# Perform inference, or save, or push to the hub
pipeline.save_pretrained("dreambooth-pipeline")
```

If you have installed `"accelerate<0.16.0"`, you first need to convert the checkpoint to an inference pipeline. This is how you could do it:

```python
from accelerate import Accelerator
@@ -271,6 +294,10 @@ accelerate launch train_dreambooth.py \

Once you have trained a model, inference can be done using the `StableDiffusionPipeline` by simply pointing it to the path where the model was saved. Make sure that your prompts include the special `identifier` used during training (`sks` in the previous examples).

**Note**: If you have installed `"accelerate>=0.16.0"`, you can use the following code to run inference from an intermediate checkpoint.


```python
from diffusers import StableDiffusionPipeline
import torch
@@ -284,4 +311,4 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```

You may also run inference from [any of the saved training checkpoints](#performing-inference-using-a-saved-checkpoint).
33 changes: 33 additions & 0 deletions examples/dreambooth/train_dreambooth.py
@@ -28,6 +28,7 @@
import torch.utils.checkpoint
from torch.utils.data import Dataset

import accelerate
import diffusers
import transformers
from accelerate import Accelerator
@@ -38,6 +39,7 @@
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
@@ -606,6 +608,37 @@ def main(args):
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
for model in models:
sub_dir = "unet" if type(model) == type(unet) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir))

# make sure to pop weight so that corresponding model is not saved again
weights.pop()

def load_model_hook(models, input_dir):
while len(models) > 0:
# pop models so that they are not loaded again
model = models.pop()

if type(model) == type(text_encoder):
# load transformers style into model
load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
model.config = load_model.config
else:
# load diffusers style into model
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
model.register_to_config(**load_model.config)

model.load_state_dict(load_model.state_dict())
del load_model

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

vae.requires_grad_(False)
if not args.train_text_encoder:
text_encoder.requires_grad_(False)
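For reference, a minimal self-contained sketch of the hook mechanism used above (assuming `accelerate>=0.16.0`; the `torch.nn.Linear` stands in for the UNet and the checkpoint path is illustrative):

```python
import os

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))


def save_model_hook(models, weights, output_dir):
    # Called by `accelerator.save_state(...)` before the default serialization.
    for m in models:
        torch.save(m.state_dict(), os.path.join(output_dir, "model.bin"))
        # Pop the weights so the model is not serialized a second time.
        weights.pop()


def load_model_hook(models, input_dir):
    while len(models) > 0:
        # Pop the model so it is not loaded a second time.
        m = models.pop()
        m.load_state_dict(torch.load(os.path.join(input_dir, "model.bin")))


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

accelerator.save_state("checkpoint-demo")  # triggers save_model_hook
accelerator.load_state("checkpoint-demo")  # triggers load_model_hook
```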
52 changes: 48 additions & 4 deletions examples/text_to_image/train_text_to_image.py
@@ -26,6 +26,7 @@
import torch.nn.functional as F
import torch.utils.checkpoint

import accelerate
import datasets
import diffusers
import transformers
@@ -36,9 +37,10 @@
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from diffusers.utils import check_min_version, deprecate
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
@@ -319,6 +321,16 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:

def main():
args = parse_args()

if args.non_ema_revision is not None:
deprecate(
"non_ema_revision!=None",
"0.15.0",
message=(
"Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
" use `--variant=non_ema` instead."
),
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)

accelerator = Accelerator(
@@ -396,6 +408,39 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet"))

# make sure to pop weight so that corresponding model is not saved again
weights.pop()

def load_model_hook(models, input_dir):
if args.use_ema:
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
ema_unet.load_state_dict(load_model.state_dict())
del load_model

for i in range(len(models)):
# pop models so that they are not loaded again
model = models.pop()

# load diffusers style into model
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
model.register_to_config(**load_model.config)

model.load_state_dict(load_model.state_dict())
del load_model

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()

@@ -552,8 +597,9 @@ def collate_fn(examples):
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)

if args.use_ema:
accelerator.register_for_checkpointing(ema_unet)
ema_unet.to(accelerator.device)

# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
@@ -566,8 +612,6 @@ def collate_fn(examples):
# Move text_encode and vae to gpu and cast to weight_dtype
text_encoder.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
if args.use_ema:
ema_unet.to(accelerator.device)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
39 changes: 38 additions & 1 deletion examples/unconditional_image_generation/train_unconditional.py
@@ -9,6 +9,7 @@
import torch
import torch.nn.functional as F

import accelerate
import datasets
import diffusers
from accelerate import Accelerator
@@ -19,6 +20,7 @@
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm

@@ -271,6 +273,40 @@ def main(args):
logging_dir=logging_dir,
)

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if args.use_ema:
ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))

for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet"))

# make sure to pop weight so that corresponding model is not saved again
weights.pop()

def load_model_hook(models, input_dir):
if args.use_ema:
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
ema_model.load_state_dict(load_model.state_dict())
ema_model.to(accelerator.device)
del load_model

for i in range(len(models)):
# pop models so that they are not loaded again
model = models.pop()

# load diffusers style into model
load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
model.register_to_config(**load_model.config)

model.load_state_dict(load_model.state_dict())
del load_model

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -336,6 +372,8 @@ def main(args):
use_ema_warmup=True,
inv_gamma=args.ema_inv_gamma,
power=args.ema_power,
model_cls=UNet2DModel,
model_config=model.config,
)

# Initialize the scheduler
@@ -411,7 +449,6 @@ def transform_images(examples):
)

if args.use_ema:
accelerator.register_for_checkpointing(ema_model)
ema_model.to(accelerator.device)

# We need to initialize the trackers we use, and also store our configuration.
49 changes: 41 additions & 8 deletions src/diffusers/training_utils.py
@@ -1,7 +1,7 @@
import copy
import os
import random
from typing import Iterable, Union
from typing import Any, Dict, Iterable, Optional, Union

import numpy as np
import torch
@@ -57,6 +57,8 @@ def __init__(
use_ema_warmup: bool = False,
inv_gamma: Union[float, int] = 1.0,
power: Union[float, int] = 2 / 3,
model_cls: Optional[Any] = None,
model_config: Dict[str, Any] = None,
**kwargs,
):
"""
@@ -123,6 +125,35 @@ def __init__(
self.power = power
self.optimization_step = 0

self.model_cls = model_cls
self.model_config = model_config

@classmethod
def from_pretrained(cls, path, model_cls) -> "EMAModel":
_, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
model = model_cls.from_pretrained(path)

ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)

ema_model.load_state_dict(ema_kwargs)
return ema_model

def save_pretrained(self, path):
if self.model_cls is None:
raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

if self.model_config is None:
raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

model = self.model_cls.from_config(self.model_config)
state_dict = self.state_dict()
state_dict.pop("shadow_params", None)
state_dict.pop("collected_params", None)

model.register_to_config(**state_dict)
self.copy_to(model.parameters())
model.save_pretrained(path)

def get_decay(self, optimization_step: int) -> float:
"""
Compute the decay factor for the exponential moving average.
@@ -184,7 +215,7 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
parameters = list(parameters)
for s_param, param in zip(self.shadow_params, parameters):
param.data.copy_(s_param.data)
param.data.copy_(s_param.to(param.device).data)

def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
@@ -257,13 +288,15 @@ def load_state_dict(self, state_dict: dict) -> None:
if not isinstance(self.power, (float, int)):
raise ValueError("Invalid power")

self.shadow_params = state_dict["shadow_params"]
if not isinstance(self.shadow_params, list):
raise ValueError("shadow_params must be a list")
if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
raise ValueError("shadow_params must all be Tensors")
shadow_params = state_dict.get("shadow_params", None)
if shadow_params is not None:
self.shadow_params = shadow_params
if not isinstance(self.shadow_params, list):
raise ValueError("shadow_params must be a list")
if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
raise ValueError("shadow_params must all be Tensors")

self.collected_params = state_dict["collected_params"]
self.collected_params = state_dict.get("collected_params", None)
if self.collected_params is not None:
if not isinstance(self.collected_params, list):
raise ValueError("collected_params must be a list")
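The new `EMAModel` serialization methods can also be used on their own. A usage sketch (model choice and path are illustrative):

```python
from diffusers import UNet2DModel
from diffusers.training_utils import EMAModel

model = UNet2DModel()  # stands in for the model being trained
ema = EMAModel(model.parameters(), model_cls=UNet2DModel, model_config=model.config)

# During training, update the shadow weights after each optimizer step:
# ema.step(model.parameters())

# Serialize the EMA config and shadow weights in the usual diffusers layout.
ema.save_pretrained("unet_ema")

# Restore the EMA state and copy the averaged weights into a live model.
ema = EMAModel.from_pretrained("unet_ema", UNet2DModel)
ema.copy_to(model.parameters())
```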