fix bug when using DP in trl, the batch size of input and output dism… #38938
kaixuanliu wants to merge 20 commits into huggingface:main from …
Conversation
Steps to reproduce the bug: … it will fail and return an error. It crashes as …
@zach-huggingface, @SunMarc and @qgallouedec, please help review.
SunMarc left a comment
Thanks! Can you add a test that covers this specific case?
@SunMarc, hi, thanks for the advice. I think the existing test is OK for this case: …
@kaixuanliu, CI has failing cases, please take a look.
@yao-matrix, I updated the code and the failing case now passes. I also double-checked the failing case on my own machine. @SunMarc, can you help review again? Thanks!
@SunMarc, hi, this PR is two weeks old now; can you help review it? Many thanks!
@qgallouedec, hi, can you help review? Thanks.
@qgallouedec, could you help review this PR?
src/transformers/trainer.py (outdated)
```python
actual_bs = None
if "labels" in inputs and isinstance(inputs["labels"], torch.Tensor):
    actual_bs = inputs["labels"].shape[0]
```
`actual_bs` could be defined inside the `if` block at line 3819, since it is only used there.
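For illustration, a self-contained sketch of the suggested scoping (the helper name and toy inputs are mine, not part of transformers): the label batch size is looked up only inside the branch that needs it, instead of pre-initializing `actual_bs = None` earlier.

```python
import torch

def labels_batch_size(inputs: dict, n_gpu: int):
    """Hypothetical helper mirroring the review suggestion: compute the
    label batch size only inside the n_gpu > 1 branch where it is used."""
    if n_gpu > 1:
        if "labels" in inputs and isinstance(inputs["labels"], torch.Tensor):
            return inputs["labels"].shape[0]
    return None

# Toy usage
inputs = {"labels": torch.zeros(8, dtype=torch.long)}
print(labels_batch_size(inputs, n_gpu=2))  # 8
print(labels_batch_size(inputs, n_gpu=1))  # None
```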
Shouldn't this be fixed in trl instead?
I'm not sure I understand: you use 3 devices, but don't run the test in a distributed manner?
@qgallouedec, hi, I think this is a corner case that is not handled properly in transformers. Although it can be avoided at the trl level or another upper application level, it is best to handle it in their common base layer (transformers). DP may be a dated approach, but since it has not been formally deprecated, it's best to fix it. WDYT?
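For readers hitting the same confusion as above: a standalone sketch (not the actual transformers code path) of how DP can get enabled without any distributed launcher, assuming a single process that simply sees several CUDA devices.

```python
import torch
from torch import nn

# Standalone illustration only: when one non-distributed process sees several
# CUDA devices, Trainer-style code wraps the model in nn.DataParallel, so DP
# is used even though no distributed launcher was involved.
n_gpu = torch.cuda.device_count()
model = nn.Linear(4, 2)
if n_gpu > 1:
    model = nn.DataParallel(model)  # one process, one replica per visible GPU
print(f"visible GPUs: {n_gpu}, DataParallel enabled: {isinstance(model, nn.DataParallel)}")
```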
SunMarc left a comment
Sorry for the delay, I have left a few questions!
```python
    assert loss_bs == self.args.n_gpu, (
        f"Expected loss to have {self.args.n_gpu} elements, but got {loss_bs} elements. "
        "This usually happens when the model does not return a loss for each device."
    )
else:
    assert loss_bs == actual_bs, (
        f"Expected loss to have {actual_bs} elements, but got {loss_bs} elements. "
        "This usually happens when the model does not return a loss for each device."
    )
```
Instead of assert, let's just raise a RuntimeError. Also, the error message doesn't help that much; is there an actionable step here for users to fix the issue?
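One possible shape for such a check, sketched as a hypothetical helper rather than the PR's code; the actionable hint about `accelerate launch` / `CUDA_VISIBLE_DEVICES` is a suggestion drawn from the rest of this discussion.

```python
def check_dp_loss_size(loss_bs: int, expected_bs: int) -> None:
    """Hypothetical replacement for the asserts above (not the merged code)."""
    if loss_bs != expected_bs:
        raise RuntimeError(
            f"Expected the loss to have {expected_bs} elements but got {loss_bs}. "
            "This can happen when training with nn.DataParallel (a single process "
            "seeing several GPUs). Consider launching the run with a distributed "
            "launcher (e.g. `accelerate launch`) or restricting it to one device "
            "via CUDA_VISIBLE_DEVICES."
        )
```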
```python
if self.args.n_gpu > 1:
    if "labels" in inputs and isinstance(inputs["labels"], torch.Tensor):
        actual_bs = inputs["labels"].shape[0]
        loss_bs = loss.shape[0] if isinstance(loss, torch.Tensor) else len(loss)
        if actual_bs >= self.args.n_gpu:
```
Do we really need those checks, given that we didn't need them until now? I feel the issue was with num_items_in_batch, not actually with the labels or the loss batch size.
```python
loss = loss.mean()  # mean() to average on multi-gpu parallel training
```
We are already averaging the loss somewhere else, no?
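For context, a minimal illustration of the reduction being discussed, assuming the usual nn.DataParallel behavior where each replica returns a scalar loss and the gather step stacks them into a 1-D tensor with one entry per device:

```python
import torch

# Illustration only: pretend two replicas each returned a scalar loss that was
# gathered into a 1-D tensor; .mean() reduces it back to a single scalar.
per_device_loss = torch.tensor([0.9, 1.1])
loss = per_device_loss.mean()  # the averaging referred to above
print(loss)  # tensor(1.)
```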
```python
# In the DataParallel case, convert the scalar tensor into a 2-dim tensor with bs = n_gpu
num_items_in_batch = num_items_in_batch.unsqueeze(0).expand(self.args.n_gpu, -1)
```
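A small standalone check of what this expansion produces (`n_gpu` and the item count below are made up): nn.DataParallel scatters tensor inputs along dim 0, so a 0-dim `num_items_in_batch` has nothing to split, while the expanded tensor has one row per replica.

```python
import torch

n_gpu = 2
num_items_in_batch = torch.tensor(48)  # 0-dim scalar tensor, shape ()

# Mirror the diff above: add a leading dim, then expand to one row per replica
# so a dim-0 scatter can hand each device its own copy of the count.
expanded = num_items_in_batch.unsqueeze(0).expand(n_gpu, -1)
print(expanded.shape)                # torch.Size([2, 1])
print(expanded.chunk(n_gpu, dim=0))  # roughly one (1, 1) slice per replica
```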
Happy to have that, but I ran the test you told me about and it passed without this PR; maybe I'm doing something wrong? `pytest -sv -rA tests/trainer/test_trainer.py::TrainerIntegrationTest::test_num_batches_in_training_with_gradient_accumulation`
Hi @SunMarc, you may need to revert trl to commit 3ef9faf257, as in my comment above, to reproduce. Anyway, I think it makes sense, as @qgallouedec mentioned, that when using DP it's best to use `accelerate launch`, so we can close this PR.
If you want to use DP, you should launch the training with …