diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index 2bbdb7a5da8f..2858c04c48b0 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -317,4 +317,7 @@ python train_dreambooth_flax.py \
   --max_train_steps=800
 ```
 
-You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
\ No newline at end of file
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
+
+You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 122d346ff5ce..377f226150ee 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -248,6 +248,9 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -516,14 +519,11 @@ def main(args):
         revision=args.revision,
     )
 
-    if is_xformers_available():
-        try:
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
 
     vae.requires_grad_(False)
     if not args.train_text_encoder:
diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
index 407578e3b717..e98e136a4b31 100644
--- a/examples/text_to_image/README.md
+++ b/examples/text_to_image/README.md
@@ -160,3 +160,6 @@ python train_text_to_image_flax.py \
   --max_grad_norm=1 \
   --output_dir="sd-pokemon-model"
 ```
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 8aac7b6b5b93..224fe471889e 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -234,6 +234,9 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -383,14 +386,11 @@ def main():
         revision=args.revision,
     )
 
-    if is_xformers_available():
-        try:
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
 
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md
index a2bde75b51de..3a7c96be69fb 100644
--- a/examples/textual_inversion/README.md
+++ b/examples/textual_inversion/README.md
@@ -124,3 +124,6 @@ python textual_inversion_flax.py \
   --output_dir="textual_inversion_cat"
 ```
 It should be at least 70% faster than the PyTorch script with the same configuration.
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 7fbca761bdc8..74fcf71cb22c 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -222,6 +222,9 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -457,14 +460,11 @@ def main():
         revision=args.revision,
     )
 
-    if is_xformers_available():
-        try:
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
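For reference, the sketch below condenses the opt-in pattern that all three training scripts now share: memory efficient attention is only enabled when the flag is passed, and a missing xformers install raises an error instead of being downgraded to a warning. This is a minimal standalone sketch, not part of the diff, assuming a working `diffusers` install; the checkpoint name is only an illustrative placeholder.

```python
import argparse

from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
args = parser.parse_args()

# Placeholder checkpoint for illustration; the training scripts load the UNet
# from --pretrained_model_name_or_path instead.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        # Swap the UNet's attention for xformers' memory efficient implementation.
        unet.enable_xformers_memory_efficient_attention()
    else:
        # Previously this was a logged warning; with an explicit flag it becomes a hard error.
        raise ValueError("xformers is not available. Make sure it is installed correctly")
```

With this in place, adding `--enable_xformers_memory_efficient_attention` to a training command opts in, while omitting the flag keeps the default attention implementation even when xformers is installed.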