fix: use seq_length instead of padded_seq_length for topk output padding #1929
In get_topk_logits, the topk outputs were incorrectly padded to padded_seq_length instead of seq_length, causing an assertion failure in get_and_validate_seqlen during distillation training. This change aligns the padding logic with get_logprobs, which correctly uses seq_length.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <qiuzhaopeng@foxmail.com>
Walkthrough
Modified padding calculation logic in the Megatron policy worker for top-k logits and indices collection.
…ing (NVIDIA-NeMo#1929) Signed-off-by: Zhaopeng Qiu <qiuzhaopeng@foxmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
…ing (#1929) Signed-off-by: Zhaopeng Qiu <qiuzhaopeng@foxmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
…ing (#1929) Signed-off-by: Zhaopeng Qiu <qiuzhaopeng@foxmail.com>
What does this PR do ?
Fix distillation training assertion failure by using `seq_length` instead of `padded_seq_length` when padding topk outputs in `get_topk_logits`, consistent with how `get_logprobs` handles the same padding.
Issues
List issues that this PR closes (syntax):
Usage
Nightly test `tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh` fails with:
In `get_topk_logits` (`megatron_policy_worker.py:942`), topk outputs were padded to `padded_seq_length` (used only for PP communication) instead of `seq_length` (the actual input sequence length). This caused `teacher_topk_indices` to have a larger sequence dimension than `input_ids`, triggering an assertion in `get_and_validate_seqlen` during student training.
The fix is a one-line change aligning with the existing pattern in `get_logprobs` (line 656):
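The actual diff is not reproduced here. As a rough illustration of why the padding target matters, the following is a minimal sketch that uses plain nested lists in place of tensors. The helper names `pad_seq_dim` and `validate_seqlen`, and all of the concrete sizes, are hypothetical and are not taken from `megatron_policy_worker.py`; `validate_seqlen` is only a toy stand-in for the kind of check `get_and_validate_seqlen` performs.

```python
# Hypothetical, simplified stand-ins: nested lists play the role of tensors.

def pad_seq_dim(rows, target_len, k, fill=0.0):
    """Right-pad a [seq, k] nested list along the sequence dimension."""
    pad = target_len - len(rows)
    if pad < 0:
        raise ValueError("target_len is shorter than the current sequence")
    return rows + [[fill] * k for _ in range(pad)]


def validate_seqlen(input_ids, topk_indices):
    """Toy check: both inputs must share the same sequence dimension."""
    assert len(input_ids) == len(topk_indices), (
        f"seq dim mismatch: {len(input_ids)} vs {len(topk_indices)}"
    )
    return len(input_ids)


seq_length = 6          # actual input sequence length (assumed value)
padded_seq_length = 8   # length rounded up for PP communication (assumed value)
k = 2                   # number of top-k entries per position (assumed value)

input_ids = list(range(seq_length))
# Assume top-k values are available for only the first 5 positions.
topk = [[0.1, 0.2] for _ in range(5)]

buggy = pad_seq_dim(topk, padded_seq_length, k)  # longer than input_ids
fixed = pad_seq_dim(topk, seq_length, k)         # same length as input_ids

validated_len = validate_seqlen(input_ids, fixed)
```

In this sketch, passing `buggy` to `validate_seqlen` raises an `AssertionError` because its sequence dimension exceeds that of `input_ids`, which mirrors the failure mode described above; padding to `seq_length` keeps the two in agreement.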
Before your PR is "Ready for review"
Pre checks:
Additional Information