Skip to content

Fix FSDP resume Initialization issue#34032

Merged
SunMarc merged 6 commits intohuggingface:mainfrom
Itssshikhar:fix-fsdp-resume
Oct 15, 2024
Merged

Fix FSDP resume Initialization issue#34032
SunMarc merged 6 commits intohuggingface:mainfrom
Itssshikhar:fix-fsdp-resume

Conversation

@Itssshikhar
Copy link
Copy Markdown
Contributor

Addresses the issue with Fully Sharded Data Parallel (FSDP) initialization when resuming training from a checkpoint. It implements a solution by adding a dummy forward pass during the initialization process.

Fixes #31892

Added tests in the test_trainer.py file to ensure proper FSDP initialization

@muellerzr @SunMarc I am creating a draft PR, let me know if there anymore changes that I can make

@SunMarc
Copy link
Copy Markdown
Member

SunMarc commented Oct 9, 2024

Thanks for the PR ! Could you explain a bit more why this PR fixes the issue that you linked ? Thanks

@Itssshikhar
Copy link
Copy Markdown
Contributor Author

Yeah, sure.

There is a similar issue in Pytorch (pytorch/pytorch#113496) which causes the same error. Reason being, initialization error in the forward pass, which causes FSDP to fail.

The Fix seems fairly simple, as we just have to run forward pass once using dummy values, before initializing FSDP.

Copy link
Copy Markdown
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Fix makes sense to me, thanks for the explanation. Could you document and add the link to that issue on top of the _init_fsdp func so we can fully trace why this is needed?

Also please do pip install -e .[quality]; make fixup and this will fix the quality tests.

@Itssshikhar Itssshikhar marked this pull request as ready for review October 15, 2024 07:04
@Itssshikhar
Copy link
Copy Markdown
Contributor Author

@muellerzr @SunMarc
All the tests have passed, but one remains tests_non_models that requires to have CUDA.

It would be great if you guys can see to it once and if there's anything from my end that needs to be done?!

Thanks!

Copy link
Copy Markdown
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! I left a comment to show what to fix in order to pass the CI !

group = trainer.get_optimizer_group(param)
self.assertIn(param, group["params"])


Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to pass a @require_cuda decorator for this test !

@SunMarc SunMarc merged commit 4de1bdb into huggingface:main Oct 15, 2024
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Copy Markdown

@ringohoffman ringohoffman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR broke the test in tests/trainer/test_trainer_fsdp.py, which actually tests initializing a trainer using FSDP.

Given that there are also some flaws in the logic of this PR, it might be worth reverting this so it can be properly relanded.

@SunMarc

dtype=torch.long,
device=device,
)
for name in model.forward.__code__.co_varnames
Copy link
Copy Markdown

@ringohoffman ringohoffman Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the variable names inside of forward... not the parameters to forward. I think you probably meant to do something like inspect.signature.

Comment on lines +296 to +301
name: torch.ones(
(1, 512),
dtype=torch.long,
device=device,
)
for name in model.forward.__code__.co_varnames
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not every parameter to forward is a tensor, but you are sending in a tensor for every value.

@Qubitium
Copy link
Copy Markdown
Contributor

Regression as result of this merge for trl/sft + fsdp training on 2x gpu.

TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'args'

trl 0.11.4
accelerate 1.0.1
transformers 4.46.0.dev
File "/python/ai/train/sft_trainer.py", line 380, in <module>
    trainer = SFTTrainer(
              ^^^^^^^^^^^
  File "/python/ai/train/sft_trainer.py", line 380, in <module>
    trainer = SFTTrainer(
              ^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 401, in __init__
    super().__init__(
  File "/root/miniconda3/lib/python3.11/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/utils/deprecation.py", line 165, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 401, in __init__
    super().__init__(
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 639, in __init__
    self.model = _init_fsdp(self.model, self.accelerator, self.args.device)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/utils/deprecation.py", line 165, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 305, in _init_fsdp
    _ = model(**dummy_input)
        ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 639, in __init__
    self.model = _init_fsdp(self.model, self.accelerator, self.args.device)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 305, in _init_fsdp
    _ = model(**dummy_input)
        ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'args'
  File "/root/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'args'

@SunMarc
Copy link
Copy Markdown
Member

SunMarc commented Oct 16, 2024

Thanks for the heads-up @Qubitium @ringohoffman ! I will revert this PR !

SunMarc added a commit that referenced this pull request Oct 16, 2024
@Itssshikhar
Copy link
Copy Markdown
Contributor Author

Thanks for info @Qubitium @ringohoffman on the PR. I'll try to resolve the errors.

muellerzr pushed a commit that referenced this pull request Oct 16, 2024
Revert "Fix FSDP resume Initialization issue (#34032)"

This reverts commit 4de1bdb.
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* Fix FSDP Initialization for resume training

* Added init_fsdp function to work with dummy values

* Fix FSDP initialization for resuming training

* Added CUDA decorator for tests

* Added torch_gpu decorator to FSDP tests

* Fixup for failing code quality tests
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
Revert "Fix FSDP resume Initialization issue (huggingface#34032)"

This reverts commit 4de1bdb.
@Itssshikhar Itssshikhar deleted the fix-fsdp-resume branch February 3, 2025 11:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Load fsdp+lora checkpoint error

7 participants