Conversation
```python
# then we average: 1/K sum_ranks (log Z - sum_i t_i * z_i)
# = log Z - 1/K sum_ranks (sum_i t_i * z_i)
# but sum_ranks (sum_i t_i * z_i) = sum_i t_i * z_i (over all vocab)
predicted_logits = predicted_logits * group.size()
```
This looks wrong, see previous comment. The previous version was tested and confirmed to work.
Was it also tested with soft labels (i.e. when the targets are logits)? Without this scaling the new test does not pass.
The reason is that when we average the loss over ranks here, we essentially compute 1/K sum_ranks (log Z - sum_i z_i t_i), where sum_i z_i t_i is the local `predicted_logits` and K is the number of ranks. What we get is then 1/K * K * log Z - 1/K * predicted_logits_global, so the 1/K that scales the global `predicted_logits` does not cancel out unless we scale it by K beforehand.
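A small self-contained sketch of this argument, simulating K vocab shards on a single device (illustrative names and shapes only, not the actual implementation):

```python
import torch

torch.manual_seed(0)
K, vocab = 4, 16                                # K "ranks", vocab split evenly
z = torch.randn(vocab)                          # full logits
t = torch.softmax(torch.randn(vocab), dim=-1)   # soft targets (sum to 1)

log_z = torch.logsumexp(z, dim=-1)              # global log Z (already all-reduced in practice)
target_loss = log_z - (t * z).sum()             # reference: log Z - sum_i t_i z_i

# Each rank only sees its vocab shard, so its "predicted_logits" is a partial dot product.
local_dots = [(t_shard * z_shard).sum() for t_shard, z_shard in zip(t.chunk(K), z.chunk(K))]

# AVG-reducing the per-rank losses keeps log Z intact but divides the global
# dot product by K: 1/K * sum_k (log Z - dot_k) = log Z - (sum_k dot_k) / K.
avg_loss_unscaled = torch.stack([log_z - d for d in local_dots]).mean()

# Scaling each local dot product by K (i.e. predicted_logits * group.size()) fixes it.
avg_loss_scaled = torch.stack([log_z - K * d for d in local_dots]).mean()

print(torch.isclose(avg_loss_scaled, target_loss))    # True
print(torch.isclose(avg_loss_unscaled, target_loss))  # False (off by the 1/K factor)
```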
Sorry, I didn't realize this was for distillation only. This one is less robustly tested, so errors are possible. But if I understand correctly, we just need to replace the mean reduction below with a sum reduction on `predicted_logits` only?
Yeah, either of the two works (see the sketch below):

- scale `predicted_logits` by group size and keep everything as is (i.e. still AVG reduction on the loss), or
- do SUM reduction on `predicted_logits` instead of AVG reduction on the loss below.
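In torch.distributed terms, the two options could look roughly like this (a sketch only: it assumes `log_z` already holds the global normalizer and `predicted_logits` is the rank-local partial dot product; the names mirror the snippet above, not the actual implementation):

```python
import torch
import torch.distributed as dist


def loss_option_scale(log_z: torch.Tensor, predicted_logits: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Option 1: pre-scale the local partial dot product by K so the 1/K from the
    # AVG reduction on the loss cancels out.
    loss = log_z - predicted_logits * group.size()
    dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=group)
    return loss / group.size()  # SUM followed by /K, i.e. the AVG reduction on the loss


def loss_option_sum(log_z: torch.Tensor, predicted_logits: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Option 2: SUM-reduce only the partial dot product across vocab shards,
    # then form the loss once with the (already global) log Z.
    dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group)
    return log_z - predicted_logits
```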
jlamypoirier left a comment
Looks good, but some suggestions on improving the tests
```
@@ -0,0 +1,185 @@
import os
```
Please move to tests/functional.
Also consider renaming the file to test_cross_entropy (to match the implementation file) and moving the existing test_cross_entropy here.
```python
torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6)


def _ce_vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool):
```
We might want to match the implementation and parametrization from test_cross_entropy
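As a rough illustration of that suggestion (hypothetical parameters and test name; the actual test_cross_entropy parametrization may differ), the TP test could be parametrized along these lines:

```python
import pytest


@pytest.mark.parametrize("use_mask", [False, True])
@pytest.mark.parametrize("target_format", ["labels", "logits"])
def test_reverse_kl_vocab_tp(use_mask: bool, target_format: str):
    # Placeholder body: the real test would spawn one process per TP rank
    # (running _ce_vocab_tp_worker) and compare against the single-rank loss.
    ...
```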
TP support for reverse KL loss.
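For reference, a minimal single-device sketch of the reverse KL distillation loss that the TP path needs to reproduce (a sketch under standard definitions, not the repository's implementation):

```python
import torch


def reverse_kl_reference(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    # Reverse KL for distillation: KL(student || teacher), averaged over tokens.
    student_log_probs = torch.log_softmax(student_logits, dim=-1)
    teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1)
    student_probs = student_log_probs.exp()
    return (student_probs * (student_log_probs - teacher_log_probs)).sum(dim=-1).mean()
```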