
Fix Trainer with a parallel model#9578

Merged
LysandreJik merged 2 commits into master from fix_trainer_model_parallel on Jan 14, 2021

Conversation

@sgugger (Collaborator) commented on Jan 13, 2021

What does this PR do?

The test introduced in #9566 wasn't actually working, as the default batch size is 8, not 16, so the problem was still there. The reason is that _setup_devices in TrainingArguments is a cached_property, so its result is computed once and for all at init. I had to change the behavior slightly, but it should be okay since it's a private method.

Fixes #9577 (the model was getting wrapped in DataParallel because the value of self.args.n_gpu was not updated).
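The caching behavior at the heart of the bug can be illustrated with a minimal sketch using Python's functools.cached_property (the class and attribute names here are illustrative, not the transformers implementation):

```python
# Minimal sketch (hypothetical names) of why a cached property does not
# reflect state changes made after its first use.
from functools import cached_property

class ArgsSketch:
    def __init__(self, n_devices):
        self.n_devices = n_devices

    @cached_property
    def setup(self):
        # Computed on first access, then cached on the instance.
        return f"configured for {self.n_devices} device(s)"

args = ArgsSketch(2)
print(args.setup)   # computes and caches the result
args.n_devices = 1  # later change is invisible to the cached value
print(args.setup)   # still reports 2 device(s)
```

A later change only becomes visible if the cached entry is invalidated (e.g. `del args.setup`), which is why state that must stay current cannot simply live inside the cached computation.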


if is_torch_available() and self.device.type != "cuda" and self.fp16:
    raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
self._n_gpu = torch.cuda.device_count()
@sgugger (Collaborator, Author):
Removing this from here; it is going to be completely set up in _setup_devices.
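The pattern the comment describes can be sketched as follows (illustrative names only, not the actual transformers code): computing _n_gpu inside the device-setup routine keeps it consistent with the rest of the device state, rather than freezing it at init.

```python
# Hypothetical sketch of the "set _n_gpu inside _setup_devices" pattern.
# device_count_fn stands in for torch.cuda.device_count so the sketch
# runs without torch; all names are illustrative.
class TrainingArgsSketch:
    def __init__(self, device_count_fn):
        self._device_count_fn = device_count_fn
        self._n_gpu = -1  # sentinel: devices not set up yet

    def _setup_devices(self):
        # The single place where device state, including _n_gpu, is decided.
        self._n_gpu = self._device_count_fn()
        return "cuda" if self._n_gpu > 0 else "cpu"

    @property
    def n_gpu(self):
        # Ensure setup has run before reporting the GPU count.
        if self._n_gpu == -1:
            self._setup_devices()
        return self._n_gpu

args = TrainingArgsSketch(lambda: 2)
print(args.n_gpu)  # 2
```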

Comment thread tests/test_trainer.py
model.is_parallelizable = True
model.model_parallel = True
trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
@sgugger (Collaborator, Author):
Make sure the test uses batch sizes of 16.

Comment thread tests/test_trainer.py
trainer = Trainer(model, args, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
# Check the Trainer was fooled
self.assertTrue(trainer.is_model_parallel)
self.assertEqual(trainer.args.n_gpu, 1)
@sgugger (Collaborator, Author):
This was still set to 2 before the fix, so this checks it is indeed 1 now.
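The decision this test exercises can be sketched schematically (should_wrap_in_data_parallel is a hypothetical helper, not part of the transformers API): a model flagged as parallelizable and already model-parallel should not be wrapped in DataParallel, even when several GPUs are visible.

```python
# Schematic sketch of the check; not the actual Trainer implementation.
def should_wrap_in_data_parallel(model, n_gpu):
    is_model_parallel = getattr(model, "is_parallelizable", False) and getattr(
        model, "model_parallel", False
    )
    # Wrap only when several GPUs are visible and the model is not
    # already spread across devices by itself.
    return n_gpu > 1 and not is_model_parallel

class ParallelModelStub:
    is_parallelizable = True
    model_parallel = True

class PlainModelStub:
    pass

print(should_wrap_in_data_parallel(ParallelModelStub(), 2))  # False
print(should_wrap_in_data_parallel(PlainModelStub(), 2))     # True
```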

Comment thread src/transformers/training_args.py
@LysandreJik (Member) left a comment:

LGTM, thanks @sgugger

@LysandreJik LysandreJik merged commit 5e1bea4 into master Jan 14, 2021
@LysandreJik LysandreJik deleted the fix_trainer_model_parallel branch January 14, 2021 08:23
LysandreJik pushed a commit that referenced this pull request Jan 14, 2021
* Fix Trainer with a parallel model

* More clean up

Successfully merging this pull request may close these issues.

Trainer is using DataParallel on parallelized models
