From ba93a84a83520e8983fab32846bee9656783a7ad Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 14 Dec 2023 10:41:39 -0600 Subject: [PATCH 1/8] load pipeline for inference only if validation prompt is used --- .../train_dreambooth_lora_sdxl_advanced.py | 59 +++++++++---------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index a46a1afcc145..f1564a116537 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -2010,43 +2010,42 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) + images = [] + if args.validation_prompt and args.num_validation_images > 0: + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) - # Final inference - # Load previous pipeline - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - - # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" - scheduler_args["variance_type"] = variance_type + scheduler_args["variance_type"] = variance_type - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - # load attention processors - pipeline.load_lora_weights(args.output_dir) + # load attention processors + pipeline.load_lora_weights(args.output_dir) - # run inference - images = [] - if args.validation_prompt and args.num_validation_images > 0: + # run inference pipeline = pipeline.to(accelerator.device) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None images = [ From b0f3d0c056b9c4d0adb09e7de159a68bdd43b9f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Thu, 14 Dec 2023 10:52:32 -0600 Subject: [PATCH 2/8] move things outside --- .../train_dreambooth_lora_sdxl_advanced.py | 161 +++++++++--------- 1 file changed, 80 insertions(+), 81 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index f1564a116537..31a3c3b5c638 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -2010,88 +2010,87 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) - images = [] - if args.validation_prompt and args.num_validation_images > 0: - # Final inference - # Load previous pipeline - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - - # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - - # load attention processors - pipeline.load_lora_weights(args.output_dir) - - # run inference - pipeline = pipeline.to(accelerator.device) - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - - if args.train_text_encoder_ti: - embedding_handler.save_embeddings( - f"{args.output_dir}/embeddings.safetensors", - ) - save_model_card( - model_id if not args.push_to_hub else repo_id, - images=images, - base_model=args.pretrained_model_name_or_path, - train_text_encoder=args.train_text_encoder, - train_text_encoder_ti=args.train_text_encoder_ti, - token_abstraction_dict=train_dataset.token_abstraction_dict, - instance_prompt=args.instance_prompt, - validation_prompt=args.validation_prompt, - repo_folder=args.output_dir, - vae_path=args.pretrained_vae_model_name_or_path, - ) - if args.push_to_hub: - upload_folder( - repo_id=repo_id, - folder_path=args.output_dir, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], - ) - accelerator.end_training() + images = [] + if args.validation_prompt and args.num_validation_images > 0: + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + pipeline = pipeline.to(accelerator.device) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.train_text_encoder_ti: + embedding_handler.save_embeddings( + f"{args.output_dir}/embeddings.safetensors", + ) + save_model_card( + model_id if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=train_dataset.token_abstraction_dict, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, + ) + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) if __name__ == "__main__": From 9d8c2d1e5e1615e79e0595a502c2dbff78c77439 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 14 Dec 2023 11:09:01 -0600 Subject: [PATCH 3/8] load pipeline for inference only if validation prompt is used --- .../train_dreambooth_lora_sdxl_advanced.py | 161 +++++++++--------- 1 file changed, 81 insertions(+), 80 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 31a3c3b5c638..f1564a116537 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -2010,87 +2010,88 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) - accelerator.end_training() - images = [] - if args.validation_prompt and args.num_validation_images > 0: - # Final inference - # Load previous pipeline - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - - # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - - # load attention processors - pipeline.load_lora_weights(args.output_dir) - - # run inference - pipeline = pipeline.to(accelerator.device) - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - - if args.train_text_encoder_ti: - embedding_handler.save_embeddings( - f"{args.output_dir}/embeddings.safetensors", - ) - save_model_card( - model_id if not args.push_to_hub else repo_id, - images=images, - base_model=args.pretrained_model_name_or_path, - train_text_encoder=args.train_text_encoder, - train_text_encoder_ti=args.train_text_encoder_ti, - token_abstraction_dict=train_dataset.token_abstraction_dict, - instance_prompt=args.instance_prompt, - validation_prompt=args.validation_prompt, - repo_folder=args.output_dir, - vae_path=args.pretrained_vae_model_name_or_path, - ) - if args.push_to_hub: - upload_folder( - repo_id=repo_id, - folder_path=args.output_dir, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], + images = [] + if args.validation_prompt and args.num_validation_images > 0: + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + pipeline = pipeline.to(accelerator.device) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.train_text_encoder_ti: + embedding_handler.save_embeddings( + f"{args.output_dir}/embeddings.safetensors", + ) + save_model_card( + model_id if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=train_dataset.token_abstraction_dict, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, ) + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() if __name__ == "__main__": From f729e903487de1e97603b219176070af10a38e97 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 14 Dec 2023 11:21:19 -0600 Subject: [PATCH 4/8] fix readme when validation prompt is used --- .../train_dreambooth_lora_sdxl_advanced.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index f1564a116537..ad37363b7d30 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -112,7 +112,7 @@ def save_model_card( repo_folder=None, vae_path=None, ): - img_str = "widget:\n" if images else "" + img_str = "widget:\n" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) img_str += f""" @@ -121,6 +121,10 @@ def save_model_card( url: "image_{i}.png" """ + if not images: + img_str += f""" + - text: '{instance_prompt}' + """ trigger_str = f"You should use {instance_prompt} to trigger the image generation." 
diffusers_imports_pivotal = "" @@ -157,8 +161,6 @@ def save_model_card( base_model: {base_model} instance_prompt: {instance_prompt} license: openrail++ -widget: - - text: '{validation_prompt if validation_prompt else instance_prompt}' --- """ From eb08727bb61b74d7fba314c11d1d0c674966a2e6 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 25 Dec 2023 11:32:29 +0200 Subject: [PATCH 5/8] chunk timesteps when using prior preservation loss + snr gamma --- .../train_dreambooth_lora_sdxl_advanced.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index ad37363b7d30..cbbc725fbb8b 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1839,9 +1839,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. - snr = compute_snr(noise_scheduler, timesteps) + + if args.with_prior_preservation: + # if we're using prior preservation, we calc snr for instance loss only - + # and hence only need timesteps corresponding to instance images + snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0) + else: + snr_timesteps = timesteps + + snr = compute_snr(noise_scheduler, snr_timesteps) base_weight = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr ) if noise_scheduler.config.prediction_type == "v_prediction": From 5a5a395dfbff5073427d404d3968583254b06362 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 25 Dec 2023 16:14:04 +0200 Subject: [PATCH 6/8] add peft changes from the canonical script --- .../train_dreambooth_lora_sdxl_advanced.py | 186 +++++++----------- 1 file changed, 70 insertions(+), 116 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index cbbc725fbb8b..432e100f817e 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -36,7 +36,10 @@ from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose from safetensors.torch import save_file @@ -54,9 +57,8 @@ UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin -from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr, unet_lora_state_dict +from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -67,39 +69,6 @@ logger = get_logger(__name__) -# TODO: This 
function should be removed once training scripts are rewritten in PEFT -def text_encoder_lora_state_dict(text_encoder): - state_dict = {} - - def text_encoder_attn_modules(text_encoder): - from transformers import CLIPTextModel, CLIPTextModelWithProjection - - attn_modules = [] - - if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): - for i, layer in enumerate(text_encoder.text_model.encoder.layers): - name = f"text_model.encoder.layers.{i}.self_attn" - mod = layer.self_attn - attn_modules.append((name, mod)) - - return attn_modules - - for name, module in text_encoder_attn_modules(text_encoder): - for k, v in module.q_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v - - for k, v in module.k_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v - - for k, v in module.v_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v - - for k, v in module.out_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v - - return state_dict - - def save_model_card( repo_id: str, images=None, @@ -1123,7 +1092,7 @@ def main(args): images = pipeline(example["prompt"]).images for i, image in enumerate(images): - hash_image = hashlib.sha1(image.tobytes()).hexdigest() + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) @@ -1136,10 +1105,10 @@ def main(args): if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - model_id = args.hub_model_id or Path(args.output_dir).name - repo_id = None if args.push_to_hub: - repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id # Load the tokenizers tokenizer_one = AutoTokenizer.from_pretrained( @@ -1262,76 +1231,37 @@ def main(args): text_encoder_two.gradient_checkpointing_enable() # now we will add new LoRA weights to the attention layers - # Set correct lora layers - unet_lora_parameters = [] - for attn_processor_name, attn_processor in unet.attn_processors.items(): - # Parse the attention module. - attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - # Set the `lora_layer` attribute of the attention-related matrices. - attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank - ) - ) - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=args.rank, - ) - ) - - # Accumulate the LoRA params to optimize. 
- unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + unet.add_adapter(unet_lora_config) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. # So, instead, we monkey-patch the forward calls of its attention-blocks. if args.train_text_encoder: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder( - text_encoder_one, dtype=torch.float32, rank=args.rank - ) - text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder( - text_encoder_two, dtype=torch.float32, rank=args.rank + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) + text_encoder_one.add_adapter(text_lora_config) + text_encoder_two.add_adapter(text_lora_config) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + for model in models: + for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) - # if we use textual inversion, we freeze all parameters except for the token embeddings - # in text encoder - elif args.train_text_encoder_ti: - text_lora_parameters_one = [] - for name, param in text_encoder_one.named_parameters(): - if "token_embedding" in name: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - param = param.to(dtype=torch.float32) - param.requires_grad = True - text_lora_parameters_one.append(param) - else: - param.requires_grad = False - text_lora_parameters_two = [] - for name, param in text_encoder_two.named_parameters(): - if "token_embedding" in name: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - param = param.to(dtype=torch.float32) - param.requires_grad = True - text_lora_parameters_two.append(param) - else: - param.requires_grad = False # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -1344,11 +1274,11 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(accelerator.unwrap_model(unet))): - unet_lora_layers_to_save = unet_lora_state_dict(model) + unet_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model) + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model) + text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1405,25 +1335,47 @@ def load_model_hook(models, input_dir): 
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) + # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) + # if we use textual inversion, we freeze all parameters except for the token embeddings + # in text encoder + elif args.train_text_encoder_ti: + text_lora_parameters_one = [] + for name, param in text_encoder_one.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_one.append(param) + else: + param.requires_grad = False + text_lora_parameters_two = [] + for name, param in text_encoder_two.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_two.append(param) + else: + param.requires_grad = False # Optimization parameters unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} if not freeze_text_encoder: # different learning rate for text encoder and unet text_lora_parameters_one_with_lr = { "params": text_lora_parameters_one, - "weight_decay": args.adam_weight_decay_text_encoder - if args.adam_weight_decay_text_encoder - else args.adam_weight_decay, + "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } text_lora_parameters_two_with_lr = { "params": text_lora_parameters_two, - "weight_decay": args.adam_weight_decay_text_encoder - if args.adam_weight_decay_text_encoder - else args.adam_weight_decay, + "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } params_to_optimize = [ @@ -2003,13 +1955,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_layers = unet_lora_state_dict(unet) + unet_lora_layers = get_peft_model_state_dict(unet) if args.train_text_encoder: text_encoder_one = accelerator.unwrap_model(text_encoder_one) - text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32)) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) text_encoder_two = accelerator.unwrap_model(text_encoder_two) - text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32)) + text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32)) else: text_encoder_lora_layers = None text_encoder_2_lora_layers = None @@ -2055,7 +2007,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # load attention processors pipeline.load_lora_weights(args.output_dir) - # run inference + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: pipeline = 
pipeline.to(accelerator.device) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None images = [ From 5cf15d10a204074526580b6dacd96ee4a9aaaa37 Mon Sep 17 00:00:00 2001 From: Linoy Date: Mon, 25 Dec 2023 15:24:43 +0000 Subject: [PATCH 7/8] style --- .../train_dreambooth_lora_sdxl_advanced.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 432e100f817e..8c9a0423397a 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -15,7 +15,6 @@ import argparse import gc -import hashlib import itertools import logging import math @@ -1262,7 +1261,6 @@ def main(args): if param.requires_grad: param.data = param.to(torch.float32) - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: From a3e1ec767da33a6c65799da049b2b52444a40794 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 25 Dec 2023 17:39:40 +0200 Subject: [PATCH 8/8] model id --- .../train_dreambooth_lora_sdxl_advanced.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 8c9a0423397a..2f1002012473 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1104,10 +1104,10 @@ def main(args): if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + model_id = args.hub_model_id or Path(args.output_dir).name + repo_id = None if args.push_to_hub: - repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token - ).repo_id + repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id # Load the tokenizers tokenizer_one = AutoTokenizer.from_pretrained(
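Taken together, [PATCH 1/8] through [PATCH 3/8] converge on one pattern: construct the final-inference pipeline only when a validation prompt is actually set, and defer accelerator.end_training() until after the model card and Hub upload. A minimal standalone sketch of that pattern follows; the function name, signature, and the fp16/"cuda" defaults are illustrative, not taken from the script:

import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline


def final_inference(model_path, lora_dir, prompt, num_images, seed=None, device="cuda"):
    # Guard first: with no validation prompt there is nothing to render,
    # so neither the VAE nor the full SDXL pipeline is ever loaded.
    if not (prompt and num_images > 0):
        return []

    pipeline = StableDiffusionXLPipeline.from_pretrained(model_path, torch_dtype=torch.float16)

    # Training used the simplified objective, so a scheduler that predicts a
    # variance ("learned" / "learned_range") must be told to use a fixed one.
    scheduler_args = {}
    if "variance_type" in pipeline.scheduler.config:
        if pipeline.scheduler.config.variance_type in ("learned", "learned_range"):
            scheduler_args["variance_type"] = "fixed_small"
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

    pipeline.load_lora_weights(lora_dir)  # attach the trained LoRA layers
    pipeline = pipeline.to(device)

    generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
    return [
        pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
        for _ in range(num_images)
    ]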
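[PATCH 5/8] is the subtlest change in the series. With --with_prior_preservation, the batch stacks instance examples on top of class examples, and the loss terms are split with torch.chunk further down; the min-SNR weights therefore have to be computed from the matching instance half of timesteps, or the weights and the instance loss go out of alignment along the batch dimension. A sketch of just that logic, assuming the [instance | class] batch layout and epsilon prediction:

import torch
from diffusers import DDPMScheduler
from diffusers.training_utils import compute_snr

noise_scheduler = DDPMScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
)
snr_gamma, with_prior_preservation = 5.0, True

# Batch of 4 = 2 instance examples followed by 2 class (prior) examples.
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (4,))

if with_prior_preservation:
    # Only the instance half of the loss is SNR-weighted, so chunk the
    # timesteps exactly the way the loss is chunked.
    snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0)
else:
    snr_timesteps = timesteps

snr = compute_snr(noise_scheduler, snr_timesteps)
# Min-SNR-gamma weighting (Section 3.4 of https://arxiv.org/abs/2303.09556),
# adapted for epsilon prediction; every tensor now has the instance-half
# batch size, so the later element-wise multiply against the loss lines up.
mse_loss_weights = (
    torch.stack([snr, snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr
)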
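[PATCH 6/8] retires the hand-rolled LoRALinearLayer wiring and the custom text_encoder_lora_state_dict walker in favor of PEFT: one LoraConfig per model, add_adapter(), and get_peft_model_state_dict() for serialization. The shape of the new setup, as a sketch — rank 4 and the SDXL base checkpoint are placeholders, and it assumes peft is installed alongside a diffusers version whose models expose add_adapter:

from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

# One declarative config replaces the per-module set_lora_layer() surgery.
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)

# Trainable parameters are now just whatever the adapter left unfrozen, and
# saving goes through PEFT's helper instead of a hand-written module walk.
unet_lora_parameters = [p for p in unet.parameters() if p.requires_grad]
unet_lora_state_dict = get_peft_model_state_dict(unet)

The same LoraConfig pattern applies to the two text encoders with target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], which is what lets the patch delete the monkey-patching branch entirely.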