From dd4b980334dbef766cb52c6cc0cf64c207b85bf1 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 6 Jan 2021 20:57:29 -0800 Subject: [PATCH 01/12] fix bad merge - dropped code --- src/transformers/trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index effa50b5a92e..e6a221bfd2b2 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -267,6 +267,11 @@ def __init__( ) self.model_init = model_init + if self.args.model_parallel and not model.is_parallelizable: + raise ValueError( + f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used" + ) + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset From 5d4fde7ee98c4825101b51d0b4f0baf2943d0bc6 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 7 Jan 2021 13:07:31 -0800 Subject: [PATCH 02/12] remove --model_parallel --- src/transformers/trainer.py | 15 +++++---------- src/transformers/training_args.py | 12 ------------ 2 files changed, 5 insertions(+), 22 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e6a221bfd2b2..b21d8e7db99e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -267,11 +267,6 @@ def __init__( ) self.model_init = model_init - if self.args.model_parallel and not model.is_parallelizable: - raise ValueError( - f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used" - ) - default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset @@ -279,7 +274,7 @@ def __init__( self.tokenizer = tokenizer # Model parallel - if not self.args.model_parallel: + if not (model.is_parallelizable and model.parallel): model = model.to(args.device) # later use `self.model is self.model_wrapped` to check if it's wrapped or not @@ -674,7 +669,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D set_seed(self.args.seed) model = self.call_model_init(trial) - if not self.args.model_parallel: + if not (model.is_parallelizable and model.parallel): model = model.to(self.args.device) self.model = model @@ -724,7 +719,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) # Multi-gpu training (should be after apex fp16 initialization) - if self.args.n_gpu > 1 and not self.args.model_parallel: + if self.args.n_gpu > 1 and not (model.is_parallelizable and model.parallel): model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) @@ -935,7 +930,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D ) if isinstance(self.model, PreTrainedModel): self.model = self.model.from_pretrained(self.state.best_model_checkpoint) - if not self.args.model_parallel: + if not (model.is_parallelizable and model.parallel): self.model = self.model.to(self.args.device) else: state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) @@ -1486,7 +1481,7 @@ def prediction_loop( model = self.model # multi-gpu eval - if self.args.n_gpu > 1 and not self.args.model_parallel: + if self.args.n_gpu > 1 and not (model.is_parallelizable and model.parallel): model = torch.nn.DataParallel(model) # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 852c97d746c4..b5ab9ccacdb3 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -210,9 +210,6 @@ class TrainingArguments: - :obj:`True` if :obj:`metric_for_best_model` is set to a value that isn't :obj:`"loss"` or :obj:`"eval_loss"`. - :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`. - model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`): - If the model supports model parallelism and there is more than one device, whether to use model parallelism - to distribute the model's modules across devices or not. ignore_skip_data (:obj:`bool`, `optional`, defaults to :obj:`False`): When resuming training, whether or not to skip the epochs and batches to get the data loading at the same stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping @@ -245,15 +242,6 @@ class TrainingArguments: do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) - model_parallel: bool = field( - default=False, - metadata={ - "help": ( - "If there are more than one devices, whether to use model parallelism to distribute the " - "model's modules across devices." - ) - }, - ) evaluation_strategy: EvaluationStrategy = field( default="no", metadata={"help": "The evaluation strategy to use."}, From cb8be9062272db6cd5a356b063c9ed6676f3a0c2 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 7 Jan 2021 17:01:52 -0500 Subject: [PATCH 03/12] Deal with TrainingArguments --- src/transformers/trainer.py | 7 +++++-- src/transformers/training_args.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index b21d8e7db99e..222447e95d54 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -276,6 +276,9 @@ def __init__( # Model parallel if not (model.is_parallelizable and model.parallel): model = model.to(args.device) + else: + # Force n_gpu to 1 to avoid DataParallel. + self.args._force_n_gpu = 1 # later use `self.model is self.model_wrapped` to check if it's wrapped or not self.model_wrapped = model @@ -719,7 +722,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) # Multi-gpu training (should be after apex fp16 initialization) - if self.args.n_gpu > 1 and not (model.is_parallelizable and model.parallel): + if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) @@ -1481,7 +1484,7 @@ def prediction_loop( model = self.model # multi-gpu eval - if self.args.n_gpu > 1 and not (model.is_parallelizable and model.parallel): + if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b5ab9ccacdb3..dcf15fb86db0 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -418,6 +418,7 @@ def __post_init__(self): if is_torch_available() and self.device.type != "cuda" and self.fp16: raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.") + self._force_n_gpu = None def __repr__(self): # We override the default repr to remove deprecated arguments from the repr. This method should be removed once @@ -480,7 +481,7 @@ def _setup_devices(self) -> Tuple["torch.device", int]: # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` # will use the first GPU in that env, i.e. GPU#1 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - n_gpu = torch.cuda.device_count() + n_gpu = torch.cuda.device_count() if self._force_n_gpu is None else self._force_n_gpu else: # Here, we'll use torch.distributed. # Initializes the distributed backend which will take care of synchronizing nodes/GPUs From 70a96932031aac6bb5e9aa79499d89f23df08c29 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 7 Jan 2021 17:29:02 -0500 Subject: [PATCH 04/12] Use a private attr and fix batch sizes --- src/transformers/trainer.py | 2 +- src/transformers/training_args.py | 15 +++++---------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 222447e95d54..432b9a1e89aa 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -278,7 +278,7 @@ def __init__( model = model.to(args.device) else: # Force n_gpu to 1 to avoid DataParallel. - self.args._force_n_gpu = 1 + self.args._n_gpu = 1 # later use `self.model is self.model_wrapped` to check if it's wrapped or not self.model_wrapped = model diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index dcf15fb86db0..1540be5edbcb 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -398,6 +398,7 @@ class TrainingArguments: default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} ) adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."}) + _n_gpu: int = field(init=False, repr=False, default=0) def __post_init__(self): if self.disable_tqdm is None: @@ -418,7 +419,7 @@ def __post_init__(self): if is_torch_available() and self.device.type != "cuda" and self.fp16: raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.") - self._force_n_gpu = None + self._n_gpu = torch.cuda.device_count() def __repr__(self): # We override the default repr to remove deprecated arguments from the repr. This method should be removed once @@ -440,10 +441,7 @@ def train_batch_size(self) -> int: "version. Using `--per_device_train_batch_size` is preferred." ) per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size - if not self.model_parallel: - train_batch_size = per_device_batch_size * max(1, self.n_gpu) - else: - train_batch_size = per_device_batch_size + train_batch_size = per_device_batch_size * max(1, self.n_gpu) return train_batch_size @property @@ -457,10 +455,7 @@ def eval_batch_size(self) -> int: "version. Using `--per_device_eval_batch_size` is preferred." ) per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size - if not self.model_parallel: - eval_batch_size = per_device_batch_size * max(1, self.n_gpu) - else: - eval_batch_size = per_device_batch_size + eval_batch_size = per_device_batch_size * max(1, self.n_gpu) return eval_batch_size @cached_property @@ -481,7 +476,7 @@ def _setup_devices(self) -> Tuple["torch.device", int]: # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` # will use the first GPU in that env, i.e. GPU#1 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - n_gpu = torch.cuda.device_count() if self._force_n_gpu is None else self._force_n_gpu + n_gpu = self._n_gpu else: # Here, we'll use torch.distributed. # Initializes the distributed backend which will take care of synchronizing nodes/GPUs From a7a39216e99aae60238962ec3d6c96ecf23da42b Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 7 Jan 2021 14:38:50 -0800 Subject: [PATCH 05/12] fix _n_gpu --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 1540be5edbcb..b4613c65fc8a 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -398,7 +398,7 @@ class TrainingArguments: default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} ) adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."}) - _n_gpu: int = field(init=False, repr=False, default=0) + _n_gpu: int = field(default=0, repr=False) def __post_init__(self): if self.disable_tqdm is None: From f9a363ca190e131ec5bd54e9cbac598f642ab066 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 7 Jan 2021 14:43:26 -0800 Subject: [PATCH 06/12] add is_parallel helper wrapper --- src/transformers/trainer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 432b9a1e89aa..c8a09a91db17 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -160,6 +160,10 @@ logger = logging.get_logger(__name__) +def is_parallel(model): + return hasattr(model, "is_parallelizable") and model.is_parallelizable and model.parallel + + def _model_unwrap(model: nn.Module) -> nn.Module: # since there could be multiple levels of wrapping, unwrap recursively if hasattr(model, "module"): @@ -274,7 +278,7 @@ def __init__( self.tokenizer = tokenizer # Model parallel - if not (model.is_parallelizable and model.parallel): + if not (is_parallel(model)): model = model.to(args.device) else: # Force n_gpu to 1 to avoid DataParallel. @@ -672,7 +676,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D set_seed(self.args.seed) model = self.call_model_init(trial) - if not (model.is_parallelizable and model.parallel): + if not is_parallel(model): model = model.to(self.args.device) self.model = model @@ -933,7 +937,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D ) if isinstance(self.model, PreTrainedModel): self.model = self.model.from_pretrained(self.state.best_model_checkpoint) - if not (model.is_parallelizable and model.parallel): + if not is_parallel(model): self.model = self.model.to(self.args.device) else: state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) From 5f273a0c41d042c86ccafd70c6ccea4ff74c8d45 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 7 Jan 2021 15:02:51 -0800 Subject: [PATCH 07/12] fix attribute --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index c8a09a91db17..7541f06504bd 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -161,7 +161,7 @@ def is_parallel(model): - return hasattr(model, "is_parallelizable") and model.is_parallelizable and model.parallel + return hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel def _model_unwrap(model: nn.Module) -> nn.Module: From 35aefd43898a3b8d9559a6286618e5d5d3823ac7 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 7 Jan 2021 15:35:59 -0800 Subject: [PATCH 08/12] introduce a new attribute is_model_parallel --- src/transformers/trainer.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7541f06504bd..06ded873f349 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -160,10 +160,6 @@ logger = logging.get_logger(__name__) -def is_parallel(model): - return hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel - - def _model_unwrap(model: nn.Module) -> nn.Module: # since there could be multiple levels of wrapping, unwrap recursively if hasattr(model, "module"): @@ -225,13 +221,15 @@ class Trainer: Important accessors: - ``self.model`` - always points to the core model. If using a transformers model, it will be a - :class:`PreTrainedModel` subclass. + * ``self.model`` - always points to the core model. If using a transformers model, it will be a + :class:`PreTrainedModel` subclass. - ``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the + * ``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``, the inner model is wrapped in ``DeepSpeed`` and then again in ``DistributedDataParallel``. If the inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``. + + * ``self.is_model_parallel`` - is true if a model has been switched to a model parallel mode. """ def __init__( @@ -271,6 +269,11 @@ def __init__( ) self.model_init = model_init + if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: + self.is_model_parallel = True + else: + self.is_model_parallel = False + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset @@ -278,7 +281,7 @@ def __init__( self.tokenizer = tokenizer # Model parallel - if not (is_parallel(model)): + if not self.is_model_parallel: model = model.to(args.device) else: # Force n_gpu to 1 to avoid DataParallel. @@ -676,7 +679,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D set_seed(self.args.seed) model = self.call_model_init(trial) - if not is_parallel(model): + if not self.is_model_parallel: model = model.to(self.args.device) self.model = model @@ -937,7 +940,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D ) if isinstance(self.model, PreTrainedModel): self.model = self.model.from_pretrained(self.state.best_model_checkpoint) - if not is_parallel(model): + if not self.is_model_parallel: self.model = self.model.to(self.args.device) else: state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) From f273f2f3cf3a21fee4251d5671d6e49ca74ee663 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 7 Jan 2021 15:42:54 -0800 Subject: [PATCH 09/12] docs --- src/transformers/trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 06ded873f349..f70be3d3b1fa 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -169,8 +169,7 @@ def _model_unwrap(model: nn.Module) -> nn.Module: class Trainer: - """ - Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. + """Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. Args: model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`): @@ -221,15 +220,16 @@ class Trainer: Important accessors: - * ``self.model`` - always points to the core model. If using a transformers model, it will be a + ``self.model`` - always points to the core model. If using a transformers model, it will be a :class:`PreTrainedModel` subclass. - * ``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the + ``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``, the inner model is wrapped in ``DeepSpeed`` and then again in ``DistributedDataParallel``. If the inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``. - * ``self.is_model_parallel`` - is true if a model has been switched to a model parallel mode. + ``self.is_model_parallel`` - is true if a model has been switched to a model parallel mode. + """ def __init__( From 8d9b78b817da11b18a56a3f40ed6020e5cabf890 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 7 Jan 2021 15:44:49 -0800 Subject: [PATCH 10/12] docs --- src/transformers/trainer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f70be3d3b1fa..bb0895592d14 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -169,7 +169,8 @@ def _model_unwrap(model: nn.Module) -> nn.Module: class Trainer: - """Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. + """ + Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. Args: model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`): @@ -218,17 +219,17 @@ class Trainer: :class:`~transformers.AdamW` on your model and a scheduler given by :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`. - Important accessors: + Important attributes: ``self.model`` - always points to the core model. If using a transformers model, it will be a - :class:`PreTrainedModel` subclass. + :class:`PreTrainedModel` subclass. ``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``, the inner model is wrapped in ``DeepSpeed`` and then again in ``DistributedDataParallel``. If the inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``. - ``self.is_model_parallel`` - is true if a model has been switched to a model parallel mode. + ``self.is_model_parallel`` - is True if a model has been switched to a model parallel mode. """ From ea8c41a8683c62345aeac061c5a0e9d3a3a0ea45 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Fri, 8 Jan 2021 09:54:27 -0500 Subject: [PATCH 11/12] Put back init False and rearrange doc --- src/transformers/trainer.py | 18 ++++++++---------- src/transformers/training_args.py | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index bb0895592d14..76e29a1ad51a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -221,16 +221,14 @@ class Trainer: Important attributes: - ``self.model`` - always points to the core model. If using a transformers model, it will be a - :class:`PreTrainedModel` subclass. - - ``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the - original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``, - the inner model is wrapped in ``DeepSpeed`` and then again in ``DistributedDataParallel``. If the inner model - hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``. - - ``self.is_model_parallel`` - is True if a model has been switched to a model parallel mode. - + - **model** -- Always points to the core model. If using a transformers model, it will be a + :class:`~transformers.PreTrainedModel` subclass. + - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the + original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``, + the inner model is wrapped in ``DeepSpeed`` and then again in ``torch.nn.DistributedDataParallel``. If the + inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``. + - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from + data parallelism, this means some of the model layers are split on different GPUs). """ def __init__( diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b4613c65fc8a..1540be5edbcb 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -398,7 +398,7 @@ class TrainingArguments: default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} ) adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."}) - _n_gpu: int = field(default=0, repr=False) + _n_gpu: int = field(init=False, repr=False, default=0) def __post_init__(self): if self.disable_tqdm is None: From 7d7c5468f2d850184ffc4d86b995f6b933248a4b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Fri, 8 Jan 2021 10:07:36 -0500 Subject: [PATCH 12/12] Ignore non-init args in HFArgumentParser --- src/transformers/hf_argparser.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index f68612126323..5192a300964f 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -53,6 +53,8 @@ def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType] def _add_dataclass_arguments(self, dtype: DataClassType): for field in dataclasses.fields(dtype): + if not field.init: + continue field_name = f"--{field.name}" kwargs = field.metadata.copy() # field.metadata is not used at all by Data Classes, @@ -142,7 +144,7 @@ def parse_args_into_dataclasses( namespace, remaining_args = self.parse_known_args(args=args) outputs = [] for dtype in self.dataclass_types: - keys = {f.name for f in dataclasses.fields(dtype)} + keys = {f.name for f in dataclasses.fields(dtype) if f.init} inputs = {k: v for k, v in vars(namespace).items() if k in keys} for k in keys: delattr(namespace, k)