
Validation Errors When Training ControlNet with FSDP #4037

@liming-ai


Describe the bug

Enabling FSDP when training with the provided ControlNet example training script leads to an unexpected error. The Accelerate config used:

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_min_num_params: 100000000
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
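To make SIZE_BASED_WRAP and fsdp_min_num_params more concrete, here is a simplified pure-Python model of size-based auto-wrapping. It assumes the rule used by PyTorch's `size_based_auto_wrap_policy`: walking bottom-up, a submodule gets its own FSDP unit once the parameters not already covered by wrapped descendants reach the threshold, and the root is always wrapped. `Node`, `plan_wrapping` and all parameter counts are hypothetical. Whatever happens inside, the top-level controlnet object itself becomes a FullyShardedDataParallel instance, and that is the object the rest of this report runs into.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for an nn.Module: its own parameter count plus children."""
    name: str
    own_params: int
    children: List["Node"] = field(default_factory=list)


def _wrap_subtree(node: Node, min_num_params: int, wrapped: List[str]) -> int:
    """Post-order walk; returns how many params under `node` are not yet in a wrapped unit."""
    remaining = node.own_params
    for child in node.children:
        remaining += _wrap_subtree(child, min_num_params, wrapped)
    if remaining >= min_num_params:
        wrapped.append(node.name)  # this node becomes its own FSDP unit
        return 0                   # its params no longer count toward the parent
    return remaining


def plan_wrapping(root: Node, min_num_params: int) -> List[str]:
    """Names of modules that would get their own FSDP unit, innermost first."""
    wrapped: List[str] = []
    for child in root.children:
        _wrap_subtree(child, min_num_params, wrapped)
    wrapped.append(root.name)  # the root is always wrapped by the outermost FSDP instance
    return wrapped


MIN_NUM_PARAMS = 100_000_000  # fsdp_min_num_params from the config above

# Hypothetical parameter counts, for illustration only.
tree = Node("controlnet", 1_000_000, [
    Node("down_blocks.0", 30_000_000),
    Node("down_blocks.1", 120_000_000),
    Node("mid_block", 90_000_000),
])

print(plan_wrapping(tree, MIN_NUM_PARAMS))  # ['down_blocks.1', 'controlnet']
```

PyTorch's real policy also refuses to wrap container types such as ModuleList and ModuleDict; this sketch ignores that detail.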

The error is caused by the following line, which fails to unwrap the FSDP-wrapped model into its original class:


# Before: type(controlnet) is <class 'torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel'>
controlnet = accelerator.unwrap_model(controlnet)
# After: type(controlnet) is still <class 'torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel'>
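The observed behaviour matches an unwrap helper that only peels wrapper types it explicitly lists and returns anything else untouched. The sketch below is a simplified model of that pattern, not Accelerate's actual unwrap_model code; every class and function name in it is hypothetical, with `UnknownWrapper` playing the role of FullyShardedDataParallel.

```python
class ControlNetModel:
    """Hypothetical stand-in for diffusers' ControlNetModel."""


class KnownWrapper:
    """Hypothetical DDP-like wrapper that the unwrap helper knows how to peel."""
    def __init__(self, module):
        self.module = module


class UnknownWrapper:
    """Hypothetical wrapper the helper does not list (stands in for FSDP here)."""
    def __init__(self, module):
        self.module = module


PEELABLE = (KnownWrapper,)


def unwrap(model):
    # Strip wrappers only while the outer object is one of the listed types;
    # any other wrapper is returned unchanged.
    while isinstance(model, PEELABLE):
        model = model.module
    return model


print(type(unwrap(KnownWrapper(ControlNetModel()))).__name__)    # ControlNetModel
print(type(unwrap(UnknownWrapper(ControlNetModel()))).__name__)  # UnknownWrapper
```

With fsdp_sharding_strategy: 1 (full sharding), simply reaching into the wrapper's .module would likely not be enough for inference on its own, since each rank holds only a shard of the parameters.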

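As a result, the default diffusers.StableDiffusionControlNetPipeline cannot be used to run inference/validation: its check_inputs does not recognise the FSDP wrapper as a ControlNet model and falls through to assert False: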

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/tiger/diffusers/examples/controlnet/train_controlnet.py:1127 in <module>                   │
│                                                                                                  │
│   1124                                                                                           │
│   1125 if __name__ == "__main__":                                                                │
│   1126 │   args = parse_args()                                                                   │
│ ❱ 1127 │   main(args)                                                                            │
│   1128                                                                                           │
│                                                                                                  │
│ /home/tiger/diffusers/examples/controlnet/train_controlnet.py:1083 in main                       │
│                                                                                                  │
│   1080 │   │   │   │   │   │   │   │   │   shutil.rmtree(removing_checkpoint)                    │
│   1081 │   │   │   │   │                                                                         │
│   1082 │   │   │   │   │   if args.validation_prompt is not None and global_step % args.validat  │
│ ❱ 1083 │   │   │   │   │   │   image_logs = log_validation(                                      │
│   1084 │   │   │   │   │   │   │   vae,                                                          │
│   1085 │   │   │   │   │   │   │   text_encoder,                                                 │
│   1086 │   │   │   │   │   │   │   tokenizer,                                                    │
│                                                                                                  │
│ /home/tiger/diffusers/examples/controlnet/train_controlnet.py:126 in log_validation              │
│                                                                                                  │
│    123 │   │                                                                                     │
│    124 │   │   for _ in range(args.num_validation_images):                                       │
│    125 │   │   │   with torch.autocast("cuda"):                                                  │
│ ❱  126 │   │   │   │   image = pipeline(                                                         │
│    127 │   │   │   │   │   validation_prompt, validation_image, num_inference_steps=20, generat  │
│    128 │   │   │   │   ).images[0]                                                               │
│    129                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.9/dist-packages/torch/utils/_contextlib.py:115 in decorate_context        │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /home/tiger/.local/lib/python3.9/site-packages/diffusers/pipelines/controlnet/pipeline_controlne │
│ t.py:840 in __call__                                                                             │
│                                                                                                  │
│    837 │   │   │   ]                                                                             │
│    838 │   │                                                                                     │
│    839 │   │   # 1. Check inputs. Raise error if not correct                                     │
│ ❱  840 │   │   self.check_inputs(                                                                │
│    841 │   │   │   prompt,                                                                       │
│    842 │   │   │   image,                                                                        │
│    843 │   │   │   callback_steps,                                                               │
│                                                                                                  │
│ /home/tiger/.local/lib/python3.9/site-packages/diffusers/pipelines/controlnet/pipeline_controlne │
│ t.py:570 in check_inputs                                                                         │
│                                                                                                  │
│    567 │   │   │   for image_ in image:                                                          │
│    568 │   │   │   │   self.check_image(image_, prompt, prompt_embeds)                           │
│    569 │   │   else:                                                                             │
│ ❱  570 │   │   │   assert False                                                                  │
│    571 │   │                                                                                     │
│    572 │   │   # Check `controlnet_conditioning_scale`                                           │
│    573 │   │   if (                                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AssertionError
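The last frame of the traceback shows check_inputs ending in an else: assert False branch. The sketch below models that failure mode, assuming (as the traceback suggests) that the preceding branches dispatch on the runtime type of the controlnet object. All classes and the check_inputs signature here are hypothetical simplifications, not the diffusers source.

```python
class ControlNetModel:
    """Hypothetical stand-in for diffusers' ControlNetModel."""


class MultiControlNetModel:
    """Hypothetical stand-in for a multi-ControlNet container."""


class FSDPLikeWrapper:
    """Hypothetical stand-in for FullyShardedDataParallel: holds the real model in .module."""
    def __init__(self, module):
        self.module = module


def check_inputs(controlnet):
    # Dispatch on the runtime type of the model object.
    if isinstance(controlnet, ControlNetModel):
        return "single"
    elif isinstance(controlnet, MultiControlNetModel):
        return "multi"
    else:
        # A wrapper object matches neither branch, so execution lands here.
        assert False


print(check_inputs(ControlNetModel()))  # single
try:
    check_inputs(FSDPLikeWrapper(ControlNetModel()))
except AssertionError:
    print("AssertionError")  # same failure mode as in the traceback above
```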

Reproduction

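Enable FSDP (for example with the Accelerate config above) and run the provided ControlNet example training script, examples/controlnet/train_controlnet.py, with validation enabled (--validation_prompt set).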

Logs

No response

System Info

  • diffusers version: 0.18.0.dev0
  • Platform: Linux-5.4.56.bsk.10-amd64-x86_64-with-glibc2.31
  • Python version: 3.9.2
  • PyTorch version (GPU?): 2.0.0+cu117 (True)
  • Huggingface_hub version: 0.15.1
  • Transformers version: 4.27.4
  • Accelerate version: 0.20.3
  • xFormers version: 0.0.18
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes

Who can help?

No response


Labels: bug (Something isn't working), stale (Issues that haven't received updates)
