From 90669060ca716190c79e384e45f8752925a9b0cf Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 12 Apr 2023 11:43:00 -0700 Subject: [PATCH 1/5] [deepspeed] partial ZeRO-3 support --- examples/text_to_image/train_text_to_image.py | 36 ++++++++++++++++--- src/diffusers/training_utils.py | 15 +++++--- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 4bbf4706f01c..a552192a5836 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -461,10 +461,38 @@ def main(): tokenizer = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision ) - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision - ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + + from accelerate.state import AcceleratorState + from transformers.utils import ContextManagers + + def get_deepspeed_plugin(): + if accelerate.state.is_initialized(): + return AcceleratorState().deepspeed_plugin + else: + return None + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = get_deepspeed_plugin() + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate will try to assign the same optimizer with the same weights to all models during deepspeed.initialize, which of course doesn't work. 
+ # For now this will partially support Deepspeed ZeRO-3, by excluding the 2 frozen models from being partitioned during zero.Init which gets called during `from_pretrained` + # So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding across multiple gpus and only UNet2DConditionModel will get zero sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 340b96e29ac5..f0379d6ce5f0 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -5,6 +5,7 @@ import numpy as np import torch +import transformers from .utils import deprecate @@ -197,11 +198,17 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): self.cur_decay_value = decay one_minus_decay = 1 - decay + if transformers.deepspeed.is_deepspeed_zero3_enabled(): + import deepspeed + context_manager = [] for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) + if transformers.deepspeed.is_deepspeed_zero3_enabled(): + context_manager = [deepspeed.zero.GatheredParameters(param, modifier_rank=None)] + with transformers.utils.ContextManagers(context_manager): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: """ From ebaa4ffd43d7f85ab51ec694c152544705059a33 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 12 Apr 2023 11:54:01 -0700 
Subject: [PATCH 2/5] cleanup --- examples/text_to_image/train_text_to_image.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index a552192a5836..de9196157818 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -481,10 +481,15 @@ def deepspeed_zero_init_disabled_context_manager(): return [deepspeed_plugin.zero3_init_context_manager(enable=False)] - # Currently Accelerate doesn't know how to handle multiple models Deepspeed ZeRO stage 3. - # For this to work properly all models must be run through `accelerate.prepare`. But accelerate will try to assign the same optimizer with the same weights to all models during deepspeed.initialize, which of course doesn't work. - # For now this will partially support Deepspeed ZeRO-3, by excluding the 2 frozen models from being partitioned during zero.Init which gets called during `from_pretrained` - # So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding across multiple gpus and only UNet2DConditionModel will get zero sharded. + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
with ContextManagers(deepspeed_zero_init_disabled_context_manager()): text_encoder = CLIPTextModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision From 6e390bebb8a66e0fd6404e5e55f00379ffc5c7a8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 11 May 2023 11:40:28 +0200 Subject: [PATCH 3/5] improve deepspeed fixes --- examples/text_to_image/train_text_to_image.py | 14 ++++---------- src/diffusers/training_utils.py | 18 ++++++++++++------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index de9196157818..3ffafcf714fb 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -36,6 +36,9 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer +from accelerate.state import AcceleratorState +from transformers.utils import ContextManagers + import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel @@ -462,20 +465,11 @@ def main(): args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision ) - from accelerate.state import AcceleratorState - from transformers.utils import ContextManagers - - def get_deepspeed_plugin(): - if accelerate.state.is_initialized(): - return AcceleratorState().deepspeed_plugin - else: - return None - def deepspeed_zero_init_disabled_context_manager(): """ returns either a context list that includes one that will disable zero.Init or an empty context list """ - deepspeed_plugin = get_deepspeed_plugin() + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None if deepspeed_plugin is None: return [] diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index f0379d6ce5f0..5926263397b9 100644 --- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py @@ -2,12 +2,16 @@ import os import random from typing import Any, Dict, Iterable, Optional, Union +import contextlib import numpy as np import torch import transformers -from .utils import deprecate +from .utils import deprecate, is_transformers_available + +if is_transformers_available(): + import transformers def enable_full_determinism(seed: int): @@ -198,13 +202,15 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): self.cur_decay_value = decay one_minus_decay = 1 - decay - if transformers.deepspeed.is_deepspeed_zero3_enabled(): + context_manager = contextlib.nullcontext() + if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): import deepspeed - context_manager = [] + for s_param, param in zip(self.shadow_params, parameters): - if transformers.deepspeed.is_deepspeed_zero3_enabled(): - context_manager = [deepspeed.zero.GatheredParameters(param, modifier_rank=None)] - with transformers.utils.ContextManagers(context_manager): + if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): + context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) + + with context_manager: if param.requires_grad: s_param.sub_(one_minus_decay * (s_param - param)) else: From 73a82874d90d26a1499e014eb8e5a74b1618fa9b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 11 May 2023 11:43:35 +0200 Subject: [PATCH 4/5] Improve --- examples/text_to_image/train_text_to_image.py | 3 +-- src/diffusers/training_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 3ffafcf714fb..8f0bda220933 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -29,6 +29,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger +from accelerate.state
import AcceleratorState from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo, upload_folder @@ -36,10 +37,8 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer -from accelerate.state import AcceleratorState from transformers.utils import ContextManagers - import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 5926263397b9..1a3abb49a065 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -1,15 +1,15 @@ +import contextlib import copy import os import random from typing import Any, Dict, Iterable, Optional, Union -import contextlib import numpy as np import torch -import transformers from .utils import deprecate, is_transformers_available + if is_transformers_available(): import transformers From def2072ff2356d172c45a9530daf25d2a744d4bb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 11 May 2023 11:44:17 +0200 Subject: [PATCH 5/5] make style --- src/diffusers/loaders.py | 2 +- tests/test_lora_layers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e814981a85c9..aa94143a8dda 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -792,7 +792,7 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): """ # Loop over the original attention modules. for name, _ in self.text_encoder.named_modules(): - if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]): + if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): # Retrieve the module and its corresponding LoRA processor. 
module = self.text_encoder.get_submodule(name) # Construct a new function that performs the LoRA merging. We will monkey patch diff --git a/tests/test_lora_layers.py b/tests/test_lora_layers.py index 9bcdc5d93301..87ec4b015363 100644 --- a/tests/test_lora_layers.py +++ b/tests/test_lora_layers.py @@ -46,7 +46,7 @@ def create_unet_lora_layers(unet: nn.Module): def create_text_encoder_lora_layers(text_encoder: nn.Module): text_lora_attn_procs = {} for name, module in text_encoder.named_modules(): - if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]): + if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) return text_encoder_lora_layers