diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 09cb48cd7f74..bd128b56f117 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -2705,7 +2705,7 @@ def _process_fsdp_args(self):
 
         if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
             for k in list(self.fsdp_config.keys()):
-                if k.startswith("fsdp_"):
+                if k.startswith("fsdp_") and k != "fsdp_version":
                     v = self.fsdp_config.pop(k)
                     self.fsdp_config[k[5:]] = v
 
@@ -2749,6 +2749,7 @@ def _process_fsdp_args(self):
         # accelerate integration for FSDP
         fsdp_plugin_args = None
         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
+            from accelerate.utils import FullyShardedDataParallelPlugin
             from accelerate.utils.constants import (
                 FSDP_AUTO_WRAP_POLICY,
                 FSDP_SHARDING_STRATEGY,
@@ -2757,7 +2758,10 @@ def _process_fsdp_args(self):
             fsdp_plugin_args = {}
             for fsdp_option in self.fsdp:
                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
-                    fsdp_plugin_args["sharding_strategy"] = fsdp_option
+                    # Set deprecated sharding_strategy from CLI (plugin maps to reshard_after_forward)
+                    # Skip if config has explicit reshard_after_forward (prioritize config)
+                    if "reshard_after_forward" not in self.fsdp_config:
+                        fsdp_plugin_args["sharding_strategy"] = fsdp_option
                 elif fsdp_option == FSDPOption.OFFLOAD:
                     fsdp_plugin_args["cpu_offload"] = True
                 elif fsdp_option == FSDPOption.AUTO_WRAP:
@@ -2769,7 +2773,8 @@ def _process_fsdp_args(self):
                     fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join(
                         self.fsdp_config["transformer_layer_cls_to_wrap"]
                     )
-            fsdp_version = int(self.fsdp_config.get("version", 1))
+
+            fsdp_version = int(self.fsdp_config.get("fsdp_version", 1))
             fsdp_plugin_args["fsdp_version"] = fsdp_version
             prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
             if fsdp_version == 2:
@@ -2801,6 +2806,12 @@ def _process_fsdp_args(self):
                 fsdp_plugin_args["sync_module_states"] = str_to_bool(sync_module_states)
 
+            # Pull allowed parameters from fsdp_config
+            ALLOWED_FSDP_PARAMS = {f.name for f in fields(FullyShardedDataParallelPlugin)}
+            for key in ALLOWED_FSDP_PARAMS:
+                if key in self.fsdp_config and key not in fsdp_plugin_args:
+                    fsdp_plugin_args[key] = self.fsdp_config[key]
+
         return fsdp_plugin_args
 
diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index 7f0cb0482bdb..9c7bd744505e 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -211,6 +211,37 @@ def test_fsdp_config(self, sharding_strategy, dtype):
         for k, v in trainer.args.fsdp_config.items():
             self.assertEqual(v, self.fsdp_config[k])
 
+    def test_fsdp_version_2_config(self):
+        output_dir = self.get_auto_remove_tmp_dir()
+        kwargs = {
+            "output_dir": output_dir,
+            "train_len": 128,
+            "save_steps": 5,
+            "learning_rate": 0.1,
+            "fsdp": True,
+            "fsdp_config": {
+                "fsdp_version": 2,
+                "reshard_after_forward": True,
+                "auto_wrap_policy": "transformer_based_wrap",
+                "transformer_cls_names_to_wrap": ["BertLayer"],
+                "state_dict_type": "FULL_STATE_DICT",
+                "activation_checkpointing": True,
+                "cpu_offload": True,
+                "limit_all_gathers": True,
+            },
+        }
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(**kwargs)
+            plugin_args = trainer.args._process_fsdp_args()
+            self.assertEqual(plugin_args["fsdp_version"], 2)
+            self.assertTrue(plugin_args["reshard_after_forward"])
+            self.assertEqual(plugin_args["auto_wrap_policy"], "transformer_based_wrap")
+            self.assertListEqual(plugin_args["transformer_cls_names_to_wrap"], ["BertLayer"])
+            self.assertEqual(plugin_args["state_dict_type"], "FULL_STATE_DICT")
+            self.assertTrue(plugin_args["activation_checkpointing"])
+            self.assertTrue(plugin_args["cpu_offload"])
+            self.assertTrue(plugin_args["limit_all_gathers"])
+
     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
     @require_torch_multi_accelerator
     @run_first