Pass required token_type_ids#4148
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
Current state of the PR seems to fix the issue:
FAILED tests/test_rloo_trainer.py::RLOOTrainerTester::test_training_vlm_0_trl_internal_testing_tiny_Gemma3ForConditionalGeneration - TypeError: RLOOTrainer._get_per_token_logps_and_entropies() got an unexpected keyword argument 'token_type_ids'
= 1 failed, 920 passed, 49 skipped, 3 xfailed, 219 warnings, 5 rerun in 906.89s (0:15:06) =The only remaining issue is now the |
|
Everything is green now! 🚀 |
| token_type_ids = forward_kwargs["token_type_ids"] | ||
| forward_kwargs["token_type_ids"] = torch.cat( | ||
| [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 | ||
| ) |
There was a problem hiding this comment.
If you validate this approach, do you think this should be implemented in other trainers as well?
| # Concatenate prompt_mask with completion_mask for logit computation | ||
| prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) | ||
| attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) | ||
| # If token_type_ids are used, extend them with zeros for the completion part |
There was a problem hiding this comment.
1 is for image and 0 for text, right?
There was a problem hiding this comment.
Yes, completion tokens are text.
|
@qgallouedec could you please validate this PR so we can finally have the CI green? |
|
Thank you for the fix 🙏 (I wouldn't be surprised if we see more models of this kind in the future -- |
Pass required
token_type_ids.Follow-up to
transformersPR:Fix #4142, fix #4150.
This PR extends support for the
token_type_idsinput across the GRPO and RLOO trainers, ensuring that models using token type information can correctly handle these inputs during training, evaluation, and loss computation.The changes are applied consistently to:
Changes
Token type IDs support:
token_type_idsas an optional argument to the_get_per_token_logps_and_entropiesmethod in bothgrpo_trainer.pyandrloo_trainer.py, allowing the trainers to process token type information.token_type_idswhen present, ensuring correct slicing and passing of token type IDs during batched forward passes.Integration with completions and output:
_generate_and_score_completionsmethod, extendedtoken_type_idswith zeros for the completion tokens and ensured they are passed through the forward arguments and included in the output dictionary.Loss computation:
_compute_lossmethod in both trainers to passtoken_type_idswhen available, ensuring that loss calculations take token type information into account.CC: @gante