Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions docs/source/en/training/dreambooth.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ accelerate launch train_dreambooth.py \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=400
--max_train_steps=400 \
--push_to_hub
```
</pt>
<jax>
Expand Down Expand Up @@ -161,7 +162,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
--max_train_steps=800 \
--push_to_hub
```
</pt>
<jax>
Expand Down Expand Up @@ -225,7 +227,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
--max_train_steps=800 \
--push_to_hub
```
</pt>
<jax>
Expand Down Expand Up @@ -387,7 +390,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
--max_train_steps=800 \
--push_to_hub
```

### 12GB GPU
Expand Down Expand Up @@ -418,7 +422,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
--max_train_steps=800 \
--push_to_hub
```

### 8 GB GPU
Expand Down Expand Up @@ -464,7 +469,8 @@ accelerate launch train_dreambooth.py \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800 \
--mixed_precision=fp16
--mixed_precision=fp16 \
--push_to_hub
```

## Inference
Expand Down
18 changes: 12 additions & 6 deletions examples/dreambooth/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ accelerate launch train_dreambooth.py \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=400
--max_train_steps=400 \
--push_to_hub
```

### Training with prior-preservation loss
Expand Down Expand Up @@ -109,7 +110,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
--max_train_steps=800 \
--push_to_hub
```


Expand Down Expand Up @@ -141,7 +143,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
--max_train_steps=800 \
--push_to_hub
```


Expand Down Expand Up @@ -176,7 +179,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
--max_train_steps=800 \
--push_to_hub
```


Expand Down Expand Up @@ -218,7 +222,8 @@ accelerate launch --mixed_precision="fp16" train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
--max_train_steps=800 \
--push_to_hub
```

### Fine-tune text encoder with the UNet.
Expand Down Expand Up @@ -251,7 +256,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
--max_train_steps=800 \
--push_to_hub
```

### Using DreamBooth for pipelines other than Stable Diffusion
Expand Down
48 changes: 47 additions & 1 deletion examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,39 @@
logger = get_logger(__name__)


def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"

yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- dreambooth
inference: true
---
"""
model_card = f"""
# DreamBooth - {repo_id}

This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
You can find some example images in the following. \n
{img_str}

DreamBooth for the text encoder was enabled: {train_text_encoder}.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)


def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
Expand Down Expand Up @@ -104,6 +137,8 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
del pipeline
torch.cuda.empty_cache()

return images


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
Expand Down Expand Up @@ -997,13 +1032,16 @@ def load_model_hook(models, input_dir):
global_step += 1

if accelerator.is_main_process:
images = []
if global_step % args.checkpointing_steps == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")

if args.validation_prompt is not None and global_step % args.validation_steps == 0:
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
images = log_validation(
text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
)

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
Expand All @@ -1024,6 +1062,14 @@ def load_model_hook(models, input_dir):
pipeline.save_pretrained(args.output_dir)

if args.push_to_hub:
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
prompt=args.instance_prompt,
repo_folder=args.output_dir,
)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
Expand Down