From f81ffdcecd39a520be119c66a9db14ab05cb64f5 Mon Sep 17 00:00:00 2001
From: kn
Date: Sun, 18 Dec 2022 21:10:53 -0500
Subject: [PATCH 1/5] Make xformers optional even if it is available

---
 examples/dreambooth/train_dreambooth.py       | 20 +++++++++++--------
 examples/text_to_image/train_text_to_image.py | 20 +++++++++++--------
 .../textual_inversion/textual_inversion.py    | 20 +++++++++++--------
 3 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 122d346ff5ce..21c7acfe5b6e 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -248,6 +248,7 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--use_xformers", action="store_true", help="Whether or not to use xformers.")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -516,14 +517,17 @@ def main(args):
         revision=args.revision,
     )
 
-    if is_xformers_available():
-        try:
-            unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+    if args.use_xformers:
+        if is_xformers_available():
+            try:
+                unet.enable_xformers_memory_efficient_attention()
+            except Exception as e:
+                logger.warning(
+                    "Could not enable memory efficient attention. Make sure xformers is installed"
+                    f" correctly and a GPU is available: {e}"
+                )
+        else:
+            logger.warning("Could not enable memory efficient attention. Make sure xformers is installed correctly.")
 
     vae.requires_grad_(False)
     if not args.train_text_encoder:
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 8aac7b6b5b93..4a2fdb05d9e4 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -234,6 +234,7 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
+    parser.add_argument("--use_xformers", action="store_true", help="Whether or not to use xformers.")
     args = parser.parse_args()
 
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -383,14 +384,17 @@ def main():
         revision=args.revision,
     )
 
-    if is_xformers_available():
-        try:
-            unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+    if args.use_xformers:
+        if is_xformers_available():
+            try:
+                unet.enable_xformers_memory_efficient_attention()
+            except Exception as e:
+                logger.warning(
+                    "Could not enable memory efficient attention. Make sure xformers is installed"
+                    f" correctly and a GPU is available: {e}"
+                )
+        else:
+            logger.warning("Could not enable memory efficient attention. Make sure xformers is installed correctly.")
 
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 7fbca761bdc8..cd6025994dd8 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -222,6 +222,7 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
+    parser.add_argument("--use_xformers", action="store_true", help="Whether or not to use xformers.")
     args = parser.parse_args()
 
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -457,14 +458,17 @@ def main():
         revision=args.revision,
     )
 
-    if is_xformers_available():
-        try:
-            unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+    if args.use_xformers:
+        if is_xformers_available():
+            try:
+                unet.enable_xformers_memory_efficient_attention()
+            except Exception as e:
+                logger.warning(
+                    "Could not enable memory efficient attention. Make sure xformers is installed"
+                    f" correctly and a GPU is available: {e}"
+                )
+        else:
+            logger.warning("Could not enable memory efficient attention. Make sure xformers is installed correctly.")
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))

From a576f93f29c5b6f488aaf32b24cd0dd677dbc0ca Mon Sep 17 00:00:00 2001
From: kn
Date: Mon, 19 Dec 2022 08:04:17 -0500
Subject: [PATCH 2/5] Raise exception if xformers is used but not available

---
 examples/dreambooth/train_dreambooth.py         | 10 ++--------
 examples/text_to_image/train_text_to_image.py   | 10 ++--------
 examples/textual_inversion/textual_inversion.py | 10 ++--------
 3 files changed, 6 insertions(+), 24 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 21c7acfe5b6e..6d4a6c7de902 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -519,15 +519,9 @@ def main(args):
 
     if args.use_xformers:
         if is_xformers_available():
-            try:
-                unet.enable_xformers_memory_efficient_attention()
-            except Exception as e:
-                logger.warning(
-                    "Could not enable memory efficient attention. Make sure xformers is installed"
-                    f" correctly and a GPU is available: {e}"
-                )
+            unet.enable_xformers_memory_efficient_attention()
         else:
-            logger.warning("Could not enable memory efficient attention. Make sure xformers is installed correctly.")
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
 
     vae.requires_grad_(False)
     if not args.train_text_encoder:
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 4a2fdb05d9e4..2de39c67e626 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -386,15 +386,9 @@ def main():
 
     if args.use_xformers:
         if is_xformers_available():
-            try:
-                unet.enable_xformers_memory_efficient_attention()
-            except Exception as e:
-                logger.warning(
-                    "Could not enable memory efficient attention. Make sure xformers is installed"
-                    f" correctly and a GPU is available: {e}"
-                )
+            unet.enable_xformers_memory_efficient_attention()
         else:
-            logger.warning("Could not enable memory efficient attention. Make sure xformers is installed correctly.")
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
 
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index cd6025994dd8..3016ed88a7e8 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -460,15 +460,9 @@ def main():
 
     if args.use_xformers:
         if is_xformers_available():
-            try:
-                unet.enable_xformers_memory_efficient_attention()
-            except Exception as e:
-                logger.warning(
-                    "Could not enable memory efficient attention. Make sure xformers is installed"
-                    f" correctly and a GPU is available: {e}"
-                )
+            unet.enable_xformers_memory_efficient_attention()
         else:
-            logger.warning("Could not enable memory efficient attention. Make sure xformers is installed correctly.")
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))

From f83a02376da0adacf6804413c2a47f8228f43860 Mon Sep 17 00:00:00 2001
From: kn
Date: Tue, 27 Dec 2022 13:26:15 -0500
Subject: [PATCH 3/5] Rename use_xformers to
 enable_xformers_memory_efficient_attention

---
 examples/dreambooth/train_dreambooth.py         | 4 ++--
 examples/text_to_image/train_text_to_image.py   | 4 ++--
 examples/textual_inversion/textual_inversion.py | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 6d4a6c7de902..82db92ca3b79 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -248,7 +248,7 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
-    parser.add_argument("--use_xformers", action="store_true", help="Whether or not to use xformers.")
+    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -517,7 +517,7 @@ def main(args):
         revision=args.revision,
     )
 
-    if args.use_xformers:
+    if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 2de39c67e626..5dadea573eee 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -234,7 +234,7 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
-    parser.add_argument("--use_xformers", action="store_true", help="Whether or not to use xformers.")
+    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
     args = parser.parse_args()
 
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -384,7 +384,7 @@ def main():
         revision=args.revision,
     )
 
-    if args.use_xformers:
+    if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 3016ed88a7e8..d0380a24cdf5 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -222,7 +222,7 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
-    parser.add_argument("--use_xformers", action="store_true", help="Whether or not to use xformers.")
+    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
     args = parser.parse_args()
 
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -458,7 +458,7 @@ def main():
         revision=args.revision,
     )
 
-    if args.use_xformers:
+    if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:

From 8dc93aef5b4e8c53fe4d20c6d3c27d3e3e3a65e1 Mon Sep 17 00:00:00 2001
From: kn
Date: Tue, 27 Dec 2022 13:32:35 -0500
Subject: [PATCH 4/5] Add a note about xformers in README

---
 examples/dreambooth/README.md        | 5 ++++-
 examples/text_to_image/README.md     | 3 +++
 examples/textual_inversion/README.md | 3 +++
 3 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index 2bbdb7a5da8f..2858c04c48b0 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -317,4 +317,7 @@ python train_dreambooth_flax.py \
   --max_train_steps=800
 ```
 
-You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
\ No newline at end of file
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
+
+You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
index 407578e3b717..e98e136a4b31 100644
--- a/examples/text_to_image/README.md
+++ b/examples/text_to_image/README.md
@@ -160,3 +160,6 @@ python train_text_to_image_flax.py \
   --max_grad_norm=1 \
   --output_dir="sd-pokemon-model"
 ```
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md
index a2bde75b51de..3a7c96be69fb 100644
--- a/examples/textual_inversion/README.md
+++ b/examples/textual_inversion/README.md
@@ -124,3 +124,6 @@ python textual_inversion_flax.py \
   --output_dir="textual_inversion_cat"
 ```
 It should be at least 70% faster than the PyTorch script with the same configuration.
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.

From ee3e63790540b44213885e3cdd7ede75672ea00b Mon Sep 17 00:00:00 2001
From: kn
Date: Tue, 27 Dec 2022 13:41:24 -0500
Subject: [PATCH 5/5] Reformat code style

---
 examples/dreambooth/train_dreambooth.py         | 4 +++-
 examples/text_to_image/train_text_to_image.py   | 4 +++-
 examples/textual_inversion/textual_inversion.py | 4 +++-
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 82db92ca3b79..377f226150ee 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -248,7 +248,9 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
-    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 5dadea573eee..224fe471889e 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -234,7 +234,9 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
-    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
     args = parser.parse_args()
 
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index d0380a24cdf5..74fcf71cb22c 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -222,7 +222,9 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
-    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
    args = parser.parse_args()
 
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
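
With the full series applied, the new flag is passed like any other script argument. Below is a minimal invocation sketch for the DreamBooth script, in the style of the README examples; the model name, data directory, prompt, and step count are illustrative placeholders, not part of the patch:

```bash
# Hypothetical example; paths, model name, and hyperparameters are placeholders.
# xFormers must be installed, otherwise the script now raises a ValueError (see patch 2).
accelerate launch examples/dreambooth/train_dreambooth.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --instance_data_dir="./instance_images" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="./dreambooth-model" \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=400 \
  --enable_xformers_memory_efficient_attention
```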