Default synced_gpus to True when using FullyShardedDataParallel #33483

Merged
muellerzr merged 17 commits into huggingface:main from ringohoffman:fsdp-generate-synced_gpus-default-to-true on Oct 10, 2024

Conversation

@ringohoffman

What does this PR do?

Fixes #30228

Related:

* pytorch/pytorch#100069
* pytorch/pytorch#123962

Similar to DeepSpeed ZeRO Stage 3, when using FSDP with multiple GPUs and differently sized data per rank, the ranks reach their synchronization points at different times: a rank that finishes generating early stops calling forward while the other ranks are still blocked in FSDP's collective ops, leading to deadlock.

To avoid this, we can automatically set synced_gpus to True if we detect that a PreTrainedModel is being managed by FSDP using _is_fsdp_managed_module, which was added in PyTorch 2.0.0 for torch.compile: https://github.com/pytorch/pytorch/blob/v2.0.0/torch/distributed/fsdp/_dynamo_utils.py

To facilitate this, I created a module called transformers.integrations.fsdp containing the function is_fsdp_managed_module, which returns True if a Module has _is_fsdp_managed_module set to True on it or if the Module itself is a FullyShardedDataParallel instance.
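
Conceptually, the default resolution in generate then looks like this (a minimal sketch of the intent, not the exact merged code; the helper name resolve_synced_gpus is just for illustration, and is_deepspeed_zero3_enabled is the existing DeepSpeed helper in transformers.integrations.deepspeed):

from typing import Optional

import torch.nn as nn

from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.integrations.fsdp import is_fsdp_managed_module


def resolve_synced_gpus(model: nn.Module, synced_gpus: Optional[bool]) -> bool:
    # Sketch: if the caller did not pass synced_gpus explicitly, default it to
    # True under DeepSpeed ZeRO Stage 3 or an FSDP-managed model, else False.
    if synced_gpus is not None:
        return synced_gpus
    return is_deepspeed_zero3_enabled() or is_fsdp_managed_module(model)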

Here is the script I used to test my fix:

OMP_NUM_THREADS=2 \
TOKENIZERS_PARALLELISM=false \
CUDA_VISIBLE_DEVICES=6,7 \
torchrun \
    --rdzv-backend=c10d \
    --rdzv-endpoint=localhost:0 \
    --nnodes=1 \
    --nproc-per-node=2 \
    fsdp_generate.py

fsdp_generate.py

import functools

import torch
import torch.distributed
import torch.distributed.fsdp
import torch.distributed.fsdp.wrap
import transformers
import transformers.models.gpt_neo.modeling_gpt_neo


def main() -> None:
    torch.distributed.init_process_group(world_size=2)
    device = torch.device(torch.distributed.get_rank())
    torch.cuda.set_device(device)

    pretrained_model_name_or_path = "EleutherAI/gpt-neo-125m"
    model = transformers.AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device_map=device,
        attn_implementation="flash_attention_2",  # I'm using flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
        torch_dtype=torch.bfloat16,
    )
    assert isinstance(model, transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id

    fsdp_model = torch.distributed.fsdp.FullyShardedDataParallel(
        model,
        auto_wrap_policy=functools.partial(
            torch.distributed.fsdp.wrap.transformer_auto_wrap_policy,
            transformer_layer_cls={
                transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock
            },
        ),
        limit_all_gathers=True,
        use_orig_params=True,  # required to overcome the error "The tensor has a non-zero number of elements, but its data is not allocated yet" ... PreTrainedModel.generate is probably using some torch.compile-wrapped function
    )

    data_by_rank = {  # differently sized causes FSDP to hang
        0: "Hello world!",
        1: "The quick brown fox jumps over the lazy dog."
    }

    batch = tokenizer(
        data_by_rank[torch.distributed.get_rank()],
        return_tensors="pt",
        return_attention_mask=True,
    ).to(device)

    with torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params(fsdp_model):  # required to overcome the error "'weight' must be 2-D"
        generated_text = fsdp_model.module.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_length=20,
            # synced_gpus=True,  # currently, True is required to use differently sized data with FSDP + generate (current default is False)
        )

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
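
For context on why synced_gpus=True prevents the hang: every rank stays in the generation loop until all ranks report they are finished, with finished ranks running dummy forward passes whose outputs are discarded, so FSDP's collective ops stay matched across ranks. A minimal sketch of the termination check (this mirrors the this_peer_finished handling in transformers' generation utilities, but the real code may differ in detail):

import torch
import torch.distributed as dist


def any_rank_unfinished(this_rank_finished: bool, device: torch.device) -> bool:
    # Each rank contributes 1.0 while it still has sequences to generate. The
    # all_reduce is itself a matched collective, and the loop only exits once
    # the sum reaches 0.0, i.e. once every rank is done.
    flag = torch.tensor(0.0 if this_rank_finished else 1.0, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    return flag.item() > 0.0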

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@gante
@ArthurZucker

Matthew Hoffman added 4 commits September 13, 2024 17:05
@ringohoffman (Author)

I should also mention that my script fails with an off-by-one error when using attn_implementation="sdpa". I don't see this error when using attn_implementation="flash_attention_2":

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/matthew/transformers/fsdp_generate.py", line 64, in <module>
[rank1]:     main()
[rank1]:   File "/home/matthew/transformers/fsdp_generate.py", line 52, in main
[rank1]:     generated_text = fsdp_model.module.generate(
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/home/matthew/transformers/src/transformers/generation/utils.py", line 2048, in generate
[rank1]:     result = self._sample(
[rank1]:   File "/home/matthew/transformers/src/transformers/generation/utils.py", line 3001, in _sample
[rank1]:     outputs = self(**model_inputs, return_dict=True)
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/matthew/transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py", line 1038, in forward
[rank1]:     transformer_outputs = self.transformer(
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/matthew/transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py", line 801, in forward
[rank1]:     outputs = block(
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
[rank1]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/matthew/transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py", line 512, in forward
[rank1]:     attn_outputs = self.attn(
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/matthew/transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py", line 462, in forward
[rank1]:     return self.attention(
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/matthew/.conda/envs/transformers310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/matthew/transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py", line 314, in forward
[rank1]:     attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
[rank1]:   File "/home/matthew/transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py", line 278, in _attn
[rank1]:     attn_weights = attn_weights + causal_mask
[rank1]: RuntimeError: The size of tensor a (21) must match the size of tensor b (20) at non-singleton dimension 3

@ringohoffman changed the title to Default synced_gpus to True when using FullyShardedDataParallel on Sep 13, 2024
@ringohoffman (Author)

The CI failures seem to be from flaky tests. This should be good for review!

@ringohoffman (Author)

Hey @gante and @ArthurZucker, just following up. Do you think you will be able to review this?

@LysandreJik (Member)

@SunMarc and @muellerzr could you take a look here please?

@ArthurZucker (Collaborator) left a comment

Looks great to me! I am just wondering if we can add a small test; fine to merge first, as this looks like it would affect a lot of users!

Comment thread on src/transformers/integrations/fsdp.py (outdated)
Matthew Hoffman and others added 4 commits October 3, 2024 14:24
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Before huggingface#33483, these tests would have hung for 10 minutes before crashing due to a timeout error
@ringohoffman (Author)

Okay, I gave it a stab based on these files plus the example in the description; I added a test for FSDP and FSDP2 (see the FSDP2 sketch after the links below).

https://github.com/huggingface/transformers/blob/main/tests/trainer/test_trainer_distributed.py
https://github.com/huggingface/transformers/blob/main/tests/generation/test_utils.py
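
For reference, FSDP2 replaces the FullyShardedDataParallel wrapper with the per-module fully_shard API. Here is a hedged sketch of what the FSDP2 path can look like when run under torchrun with 2 processes (it assumes torch >= 2.4's torch.distributed._composable.fsdp and reuses the model/prompts from the script in the description; the merged tests may differ):

import torch
import torch.distributed
import transformers
from torch.distributed._composable.fsdp import fully_shard


torch.distributed.init_process_group()
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank)
torch.cuda.set_device(device)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125m", torch_dtype=torch.bfloat16
).to(device)
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
tokenizer.pad_token_id = tokenizer.eos_token_id

# FSDP2: shard each transformer block in place, then the root model.
for block in model.transformer.h:
    fully_shard(block)
fully_shard(model)

# Differently sized prompts per rank: before this PR, generate() would hang
# here unless the caller remembered to pass synced_gpus=True.
prompt = "Hello world!" if rank == 0 else "The quick brown fox jumps over the lazy dog."
batch = tokenizer(prompt, return_tensors="pt").to(device)
output_ids = model.generate(**batch, max_length=20)

torch.distributed.destroy_process_group()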

Matthew Hoffman added 2 commits October 3, 2024 16:06
I think this might cause more problems if one of the workers was killed
@SunMarc (Member) left a comment

Thanks for adding this! This is very helpful! I left a comment. Can you have a second look, @muellerzr?
Also, I think we could create a doc explaining how to perform generate when the model is initialized under FSDP or DeepSpeed! I've seen a lot of users being confused about this.
Also, if you want to dig deeper, it would be nice to have a comparison with the DDP/DeepSpeed analysis that was done here: huggingface/trl#1483 (comment)

Comment thread on src/transformers/integrations/fsdp.py (outdated), commenting on lines +28 to +31:
def is_fsdp_managed_module(module: nn.Module) -> bool:
return isinstance(module, torch.distributed.fsdp.FullyShardedDataParallel) or getattr(
module, "_is_fsdp_managed_module", False
)
@SunMarc (Member)

Can't we use is_fsdp_enabled, just like how we do it for DeepSpeed? I guess users will use FSDP via Trainer + accelerate, so that should work. But I understand that it won't work with your general script. cc @muellerzr

@ringohoffman (Author)

But I understand that it won't work with your general script

Right, since is_fsdp_enabled is a Trainer attribute, it would only work if you are using a Trainer, whereas this function works in all scenarios, including when using a Trainer.
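
To illustrate the difference, using the wrapped fsdp_model from the script in the description (FSDP annotates managed submodules with _is_fsdp_managed_module for torch.compile, which is what the second check picks up; a sketch, assuming torch.distributed is initialized):

from transformers.integrations.fsdp import is_fsdp_managed_module

# Works with a bare FSDP wrapper, no Trainer involved:
assert is_fsdp_managed_module(fsdp_model)         # FullyShardedDataParallel instance
assert is_fsdp_managed_module(fsdp_model.module)  # _is_fsdp_managed_module flag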

@SunMarc SunMarc requested a review from muellerzr October 7, 2024 14:04
@muellerzr (Contributor) left a comment

Thanks, this looks good to me. Just one suggestion, please, to make sure we can test the Accelerate base case as well.

Comment thread on src/transformers/integrations/fsdp.py
@muellerzr (Contributor) left a comment

Thanks! These tests look great!

@muellerzr merged commit 70b07d9 into huggingface:main on Oct 10, 2024
@ringohoffman deleted the fsdp-generate-synced_gpus-default-to-true branch on October 10, 2024 18:10
@YeLuoSuiYou commented Oct 16, 2024

Hi, I think there may be a bug in the Trainer FSDP init: when I use pytest to run the test code in tests/trainer/test_trainer_fsdp.py, I get an error.
(error screenshot omitted)
Could you give me some advice on how to avoid it? @ringohoffman

My package versions are:
torch 2.4.1
accelerate 0.34.2
transformers 4.46.0.dev0

Thanks a lot!

@ringohoffman (Author)

@YeLuoSuiYou Seems like it was due to the logic added in this PR #34032

@YeLuoSuiYou

@YeLuoSuiYou Seems like it was due to the logic added in this PR #34032

Thanks a lot. Maybe I should wait for the next release to fix it? By the way, thanks a lot for the functionality provided by this PR!

@YeLuoSuiYou

@YeLuoSuiYou Seems like it was due to the logic added in this PR #34032

I commented out this function and the test now works well, thanks!

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…huggingface#33483)

* Default synced_gpus to True when using FullyShardedDataParallel

Fixes huggingface#30228

Related:

* pytorch/pytorch#100069
* pytorch/pytorch#123962

Similar to DeepSpeed ZeRO Stage 3, when using FSDP with multiple GPUs and differently sized data per rank, the ranks reach different synchronization points at the same time, leading to deadlock

To avoid this, we can automatically set synced_gpus to True if we detect that a PreTrainedModel is being managed by FSDP using _is_fsdp_managed_module, which was added in 2.0.0 for torch.compile: https://github.com/pytorch/pytorch/blob/v2.0.0/torch/distributed/fsdp/_dynamo_utils.py

* Remove test file

* ruff formatting

* ruff format

* Update copyright year

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add test for FSDP-wrapped model generation

Before huggingface#33483, these tests would have hung for 10 minutes before crashing due to a timeout error

* Ruff format

* Move argparse import

* Remove barrier

I think this might cause more problems if one of the workers was killed

* Move import into function to decrease load time

huggingface#33483 (comment)

* Add test for accelerate and Trainer

huggingface#33483 (comment)

* Refactor imports

* Ruff format

* Use nullcontext

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

Successfully merging this pull request may close these issues:

FSDP Doesn't Work with model.generate() (#30228)