From 8a21d7e0a98ca0602598361408d6b94f5bb77f6e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 18 Jan 2023 18:51:16 +0000 Subject: [PATCH 1/4] [Lora] up lora training --- examples/dreambooth/train_dreambooth_lora.py | 69 +++++++++++++------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 66986e6c8f00..44d0e09dcdb3 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -58,6 +58,22 @@ logger = get_logger(__name__) +def create_model_card(images, base_model, prompt): + markdown_lines = [] + --- +license: creativeml-openrail-m +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +inference: true +--- + + + + + def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, @@ -913,34 +929,37 @@ def main(args): unet = unet.to(torch.float32) unet.save_attn_procs(args.output_dir) + # Final inference + # Load previous pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype + ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + + # load attention processors + pipeline.unet.load_attn_procs(args.output_dir) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + prompt = args.num_validation_images * [args.validation_prompt] + images = pipeline(prompt, num_inference_steps=25, generator=generator).images + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + model_card = create_model_card(images, base_model=args.pretrained_model_name_or_path, prompt=args.valiadtion_prompt) + if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) - # Final inference - # Load previous pipeline - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype - ) - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) - pipeline = pipeline.to(accelerator.device) - - # load attention processors - pipeline.unet.load_attn_procs(args.output_dir) - - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) - prompt = args.num_validation_images * [args.validation_prompt] - images = pipeline(prompt, num_inference_steps=25, generator=generator).images - - for tracker in accelerator.trackers: - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) - ] - } - ) accelerator.end_training() From f39ec8459c8fc0e90a8cbe5b68c9fa856441e47f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 18 Jan 2023 19:54:29 +0100 Subject: [PATCH 2/4] finish --- examples/dreambooth/train_dreambooth_lora.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 44d0e09dcdb3..fb2e6cbef2dd 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -59,9 +59,10 @@ def create_model_card(images, base_model, prompt): - markdown_lines = [] - --- + yaml = f""" +--- license: creativeml-openrail-m +base_model: {base_model} tags: - stable-diffusion - stable-diffusion-diffusers @@ -69,6 +70,7 @@ def create_model_card(images, base_model, prompt): - diffusers inference: true --- + """ From b2068632a4453d4decc52b4ba96074f8813d943f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 19 Jan 2023 00:09:27 +0100 Subject: [PATCH 3/4] finish --- examples/dreambooth/train_dreambooth_lora.py | 29 +++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index fb2e6cbef2dd..095698654456 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -58,7 +58,12 @@ logger = get_logger(__name__) -def create_model_card(images, base_model, prompt): +def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += "![img_{i}](./image_{i}.png)\n" + yaml = f""" --- license: creativeml-openrail-m @@ -71,9 +76,14 @@ def create_model_card(images, base_model, prompt): inference: true --- """ + model_card = f""" +# LoRA DreamBooth - {repo_name} - - +These are LoRA adaption weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): @@ -952,17 +962,22 @@ def main(args): tracker.log( { "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) ] } ) - model_card = create_model_card(images, base_model=args.pretrained_model_name_or_path, prompt=args.valiadtion_prompt) - if args.push_to_hub: + save_model_card( + repo_name, + images=images, + base_model=args.pretrained_model_name_or_path, + prompt=args.instance_prompt, + repo_folder=args.output_dir, + ) repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) - accelerator.end_training() From ea1b6fb543aaad9cd7a730f449b1ff07fe4a6a1c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 18 Jan 2023 23:22:06 +0000 Subject: [PATCH 4/4] finish model card --- examples/dreambooth/train_dreambooth_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 095698654456..f516fd1b638b 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -62,7 +62,7 @@ def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_fol img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += "![img_{i}](./image_{i}.png)\n" + img_str += f"![img_{i}](./image_{i}.png)\n" yaml = f""" ---