diff --git a/docs/source/en/training/dreambooth.mdx b/docs/source/en/training/dreambooth.mdx index c5a5a047d114..09b877c7d0cc 100644 --- a/docs/source/en/training/dreambooth.mdx +++ b/docs/source/en/training/dreambooth.mdx @@ -98,7 +98,8 @@ accelerate launch train_dreambooth.py \ --learning_rate=5e-6 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ - --max_train_steps=400 + --max_train_steps=400 \ + --push_to_hub ``` @@ -161,7 +162,8 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` @@ -225,7 +227,8 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` @@ -387,7 +390,8 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` ### 12GB GPU @@ -418,7 +422,8 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` ### 8 GB GPU @@ -464,7 +469,8 @@ accelerate launch train_dreambooth.py \ --lr_warmup_steps=0 \ --num_class_images=200 \ --max_train_steps=800 \ - --mixed_precision=fp16 + --mixed_precision=fp16 \ + --push_to_hub ``` ## Inference diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index e1eb8a06b0ff..490e31458988 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -80,7 +80,8 @@ accelerate launch train_dreambooth.py \ --learning_rate=5e-6 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ - --max_train_steps=400 + --max_train_steps=400 \ + --push_to_hub ``` ### Training with prior-preservation loss @@ -109,7 +110,8 @@ accelerate launch train_dreambooth.py \ 
def save_model_card(
    repo_id: str,
    images=None,
    base_model: str = None,
    train_text_encoder=False,
    prompt: str = None,
    repo_folder=None,
):
    """Write a DreamBooth model card (``README.md``) and sample images to ``repo_folder``.

    Each entry of ``images`` is saved as ``image_<i>.png`` inside ``repo_folder``
    and linked from the card. The card consists of a YAML metadata header followed
    by a short Markdown description.

    Args:
        repo_id: Repository identifier shown in the card title.
        images: Iterable of objects exposing ``save(path)`` (the caller passes the
            validation images). ``None`` or an empty iterable produces a card
            without example images.
        base_model: Name or path of the model that was fine-tuned; interpolated
            into both the metadata header and the description.
        train_text_encoder: Whether the text encoder was trained as well; reported
            verbatim in the card.
        prompt: Instance prompt used for training; interpolated into both the
            metadata header and the description.
        repo_folder: Directory that receives ``README.md`` and the image files.
            It is joined with ``os.path.join`` and opened for writing, so it has
            to be an existing directory path.
    """
    # `images` defaults to None; treat that as "no example images" instead of
    # failing inside enumerate().
    image_links = []
    for i, image in enumerate(images or []):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        image_links.append(f"![img_{i}](./image_{i}.png)\n")
    img_str = "".join(image_links)

    yaml_header = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- dreambooth
inference: true
---
    """
    model_card = f"""
# DreamBooth - {repo_id}

This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
You can find some example images in the following. \n
{img_str}

DreamBooth for the text encoder was enabled: {train_text_encoder}.
"""
    # Explicit UTF-8 so prompts with non-ASCII characters are written correctly
    # regardless of the platform's default locale encoding.
    with open(os.path.join(repo_folder, "README.md"), "w", encoding="utf-8") as f:
        f.write(yaml_header + model_card)