diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 190f4625a16c..b30f8132ba96 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -608,6 +608,17 @@ def __getitem__(self, index):
         example["index"] = index
         return example
 
+def enable_xformers_for_object(obj_name):
+    if is_xformers_available():
+        import xformers
+        xformers_version = version.parse(xformers.__version__)
+        if xformers_version == version.parse("0.0.16"):
+            logger.warn(
+                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+            )
+        obj_name.enable_xformers_memory_efficient_attention()
+    else:
+        raise ValueError("xformers is not available. Make sure it is installed correctly")
 
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
@@ -676,6 +687,9 @@ def main(args):
         )
         pipeline.set_progress_bar_config(disable=True)
 
+        if args.enable_xformers_memory_efficient_attention:
+            enable_xformers_for_object(pipeline)
+
         num_new_images = args.num_class_images - cur_class_images
         logger.info(f"Number of class images to sample: {num_new_images}.")
 
@@ -769,17 +783,7 @@ def load_model_hook(models, input_dir):
         text_encoder.requires_grad_(False)
 
     if args.enable_xformers_memory_efficient_attention:
-        if is_xformers_available():
-            import xformers
-
-            xformers_version = version.parse(xformers.__version__)
-            if xformers_version == version.parse("0.0.16"):
-                logger.warn(
-                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
-                )
-            unet.enable_xformers_memory_efficient_attention()
-        else:
-            raise ValueError("xformers is not available. Make sure it is installed correctly")
+        enable_xformers_for_object(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 5cefc57c614d..53f396888c02 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -535,7 +535,18 @@ def __getitem__(self, index):
         example["index"] = index
         return example
 
-
+def enable_xformers_for_object(obj_name):
+    if is_xformers_available():
+        import xformers
+        xformers_version = version.parse(xformers.__version__)
+        if xformers_version == version.parse("0.0.16"):
+            logger.warn(
+                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+            )
+        obj_name.enable_xformers_memory_efficient_attention()
+    else:
+        raise ValueError("xformers is not available. Make sure it is installed correctly")
+
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
@@ -604,6 +615,9 @@ def main(args):
         )
         pipeline.set_progress_bar_config(disable=True)
 
+        if args.enable_xformers_memory_efficient_attention:
+            enable_xformers_for_object(pipeline)
+
         num_new_images = args.num_class_images - cur_class_images
         logger.info(f"Number of class images to sample: {num_new_images}.")
 
@@ -680,17 +694,7 @@ def main(args):
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
     if args.enable_xformers_memory_efficient_attention:
-        if is_xformers_available():
-            import xformers
-
-            xformers_version = version.parse(xformers.__version__)
-            if xformers_version == version.parse("0.0.16"):
-                logger.warn(
-                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
-                )
-            unet.enable_xformers_memory_efficient_attention()
-        else:
-            raise ValueError("xformers is not available. Make sure it is installed correctly")
+        enable_xformers_for_object(unet)
 
     # now we will add new LoRA weights to the attention layers
     # It's important to realize here how many attention weights will be added and of which sizes
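
For reference, a minimal standalone sketch of the helper this patch factors out, runnable outside the training scripts. The plain `logging` logger and the parameter name `obj` are illustrative assumptions; the scripts themselves use accelerate's `get_logger` and pass the `DiffusionPipeline` or the UNet directly.

# Standalone sketch of the shared xformers-enabling helper (illustrative only).
# Assumes `diffusers` and `packaging` are installed; `xformers` is optional.
import logging

from packaging import version
from diffusers.utils.import_utils import is_xformers_available

logger = logging.getLogger(__name__)


def enable_xformers_for_object(obj):
    """Enable memory-efficient attention on any object exposing
    `enable_xformers_memory_efficient_attention()` (a pipeline or a UNet)."""
    if not is_xformers_available():
        raise ValueError("xformers is not available. Make sure it is installed correctly")
    import xformers

    if version.parse(xformers.__version__) == version.parse("0.0.16"):
        logger.warning(
            "xFormers 0.0.16 cannot be used for training in some GPUs. "
            "If you observe problems during training, please update xFormers to at least 0.0.17."
        )
    obj.enable_xformers_memory_efficient_attention()

Called once on the class-image generation pipeline and once on the UNet, this replaces the two duplicated version-check blocks in each training script.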