diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index e8c2823f3793..40f321fcb294 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -397,7 +397,7 @@ class TrainingArguments: When resuming training, whether or not to skip the epochs and batches to get the data loading at the same stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step can take a long time) but will not yield the same results as the interrupted training would have. - sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `False`): + sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `''`): Use Sharded DDP training from [FairScale](https://github.com/facebookresearch/fairscale) (in distributed training only). This is an experimental feature. @@ -412,7 +412,7 @@ class TrainingArguments: If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty list for `False` and `["simple"]` for `True`. - fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `False`): + fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`): Use PyTorch Distributed Parallel Training (in distributed training only). A list of options along the following: @@ -944,7 +944,7 @@ class TrainingArguments: ) }, ) - sharded_ddp: str = field( + sharded_ddp: Optional[Union[List[ShardedDDPOption], str]] = field( default="", metadata={ "help": ( @@ -955,7 +955,7 @@ class TrainingArguments: ), }, ) - fsdp: str = field( + fsdp: Optional[Union[List[FSDPOption], str]] = field( default="", metadata={ "help": ( @@ -980,8 +980,8 @@ class TrainingArguments: default=None, metadata={ "help": ( - "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a" - "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." + "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a" + "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." ) }, ) @@ -994,11 +994,11 @@ class TrainingArguments: ) }, ) - deepspeed: Optional[str] = field( + deepspeed: Optional[Union[str, Dict]] = field( default=None, metadata={ "help": ( - "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already" + "Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already" " loaded json file as a dict" ) },