Fix FSDP2 defaulting to version 1 in TrainingArguments; add dynamic plugin param passthrough #42521
amanzoni1 wants to merge 7 commits into huggingface:main
Conversation
Force-pushed f19d66a to bcd3599.
```python
fsdp_plugin_args["fsdp_version"] = self.fsdp_config.get("fsdp_version", 1)
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper()
fsdp_plugin_args["forward_prefetch"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()

sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
if sync_module_states == "false" and cpu_ram_efficient_loading == "true":
    # In this case, all the processes except the main process would have random weights leading
    # to unexpected behaviour during training, thus throwing error here to prevent it.
    raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')

# we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers
fsdp_plugin_args["cpu_ram_efficient_loading"] = cpu_ram_efficient_loading
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading

fsdp_plugin_args["sync_module_states"] = sync_module_states
fsdp_plugin_args["use_orig_params"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
```
I think it would be better to keep this for now, as we can easily change the default config here instead of doing it in accelerate. Also, some defaults are not the same here vs. accelerate. Nevertheless, for keys that are not exposed here but exist on FullyShardedDataParallelPlugin, it's fine to set them.
```python
if key in self.fsdp_config and key not in fsdp_plugin_args:
    fsdp_plugin_args[key] = self.fsdp_config[key]
```
We can keep this, but put it at the end.
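A minimal sketch of the dynamic passthrough under discussion, assuming a hypothetical `PluginStub` dataclass standing in for accelerate's `FullyShardedDataParallelPlugin` (the field names here are illustrative, not the real plugin's full field list):

```python
from dataclasses import dataclass, fields

# Hypothetical stand-in for accelerate's FullyShardedDataParallelPlugin;
# only the field names matter for the passthrough.
@dataclass
class PluginStub:
    fsdp_version: int = 1
    reshard_after_forward: bool = True
    cpu_offload: bool = False

def passthrough(fsdp_config, fsdp_plugin_args):
    # Copy any config key that matches a plugin field, unless the
    # explicit handling earlier already set it (explicit keys win).
    plugin_keys = {f.name for f in fields(PluginStub)}
    for key, value in fsdp_config.items():
        if key in plugin_keys and key not in fsdp_plugin_args:
            fsdp_plugin_args[key] = value
    return fsdp_plugin_args

args = passthrough(
    {"fsdp_version": 2, "cpu_offload": True, "not_a_plugin_field": 1},
    {"fsdp_version": 2},
)
print(args)  # {'fsdp_version': 2, 'cpu_offload': True}
```

Because the plugin's fields are introspected rather than hardcoded, new plugin params in accelerate pass through without further changes here.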
```python
# Set deprecated sharding_strategy from CLI (plugin maps to reshard_after_forward)
# Skip if config has explicit reshard_after_forward (prioritize config)
if "reshard_after_forward" not in self.fsdp_config:
    fsdp_plugin_args["sharding_strategy"] = fsdp_option
```
```python
if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
    for k in list(self.fsdp_config.keys()):
        if k.startswith("fsdp_") and k != "fsdp_version":
```
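The effect of the `k != "fsdp_version"` guard can be illustrated with a small, hypothetical helper (not the actual Trainer code) that strips the legacy `fsdp_` prefix from config keys while keeping `fsdp_version` intact:

```python
def normalize_fsdp_config(fsdp_config):
    # Strip the legacy "fsdp_" prefix from config keys, but keep
    # "fsdp_version" as-is so the later plugin lookup still finds it.
    out = {}
    for k, v in fsdp_config.items():
        if k.startswith("fsdp_") and k != "fsdp_version":
            out[k[len("fsdp_"):]] = v
        else:
            out[k] = v
    return out

cfg = normalize_fsdp_config({"fsdp_version": 2, "fsdp_offload_params": True})
print(cfg)  # {'fsdp_version': 2, 'offload_params': True}
```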
```python
def test_fsdp_version_2_config(self):
    output_dir = self.get_auto_remove_tmp_dir()
    kwargs = {
        "output_dir": output_dir,
        "train_len": 128,
        "save_steps": 5,
        "learning_rate": 0.1,
        "fsdp": True,
        "fsdp_config": {
            "fsdp_version": 2,
            "reshard_after_forward": True,
        },
    }
    with mockenv_context(**self.dist_env_1_gpu):
        trainer = get_regression_trainer(**kwargs)
        plugin_args = trainer.args._process_fsdp_args()
        self.assertEqual(plugin_args["fsdp_version"], 2)
        self.assertTrue(plugin_args["reshard_after_forward"])
```
Thanks for the feedback @SunMarc!!

I've merged this PR, which seems to fix a few more things! Feel free to rebase and add your updates; the test is especially nice!
Force-pushed 9a8af59 to b6a9f0b.
@SunMarc Thanks for the feedback! Rebased onto the latest main and ready for review.
```python
# HF-to-plugin map
if (
    "transformer_layer_cls_to_wrap" in self.fsdp_config
    and "transformer_cls_names_to_wrap" not in fsdp_plugin_args
):
    fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join(
        self.fsdp_config["transformer_layer_cls_to_wrap"]
    )
```
```python
"fsdp_config": {
    "fsdp_version": 2,
    "reshard_after_forward": True,
},
```
Can you add more things in the config so that it is better tested?
Force-pushed 61c0e74 to 2252af4.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Head branch was pushed to by a user without write access
Force-pushed 3761cdd to ff47d0d.
Hey @SunMarc, not sure if there's anything more I should do here. The failed CI doesn't seem code-related. Thanks!
cc @ArthurZucker, can you merge?
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42521&sha=661f54

View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42521&sha=dd3f32
What does this PR do?
Fixes the bug where `TrainingArguments(fsdp=True, fsdp_config={"fsdp_version": 2, ...})` defaults to FSDP version 1, ignoring the requested version (unless the Accelerator was initialized manually), with most params being lost. Adds dynamic passthrough of all FSDP plugin params (`FullyShardedDataParallelPlugin`) from `fsdp_config` to the plugin args (future-proof, no hardcoding).
Changes
- Skip `fsdp_version` during `fsdp_`-prefix pre-stripping to preserve it.
- Introspect the `FullyShardedDataParallelPlugin` dataclass for all fields.
- Add `test_fsdp_version_2_config` for regression.

Repro
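A self-contained sketch of the failure mode (modeling the behaviour rather than calling transformers itself): before the fix, the prefix-stripping pass renamed `fsdp_version` along with every other `fsdp_`-prefixed key, so the later version lookup silently fell back to its default of 1:

```python
def strip_prefix(cfg, keep_version=False):
    # Simplified model of the config pre-processing: keep_version=False
    # mimics the old behaviour, keep_version=True the fixed one.
    out = {}
    for k, v in cfg.items():
        if k.startswith("fsdp_") and not (keep_version and k == "fsdp_version"):
            out[k[len("fsdp_"):]] = v
        else:
            out[k] = v
    return out

cfg = {"fsdp_version": 2}
buggy = strip_prefix(cfg)                     # becomes {"version": 2}
fixed = strip_prefix(cfg, keep_version=True)  # stays {"fsdp_version": 2}
print(buggy.get("fsdp_version", 1))  # 1 -> silently falls back to FSDP v1
print(fixed.get("fsdp_version", 1))  # 2
```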
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case: https://discuss.huggingface.co/t/how-to-start-fsdp2-when-using-trainer/151885
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@SunMarc @3outeille
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.