Skip to content

TP support for reverse KL loss #400

Merged
oleksost merged 11 commits intomainfrom
rev_kl_tp
Dec 4, 2025
Merged

TP support for reverse KL loss #400
oleksost merged 11 commits intomainfrom
rev_kl_tp

Conversation

@oleksost
Copy link
Copy Markdown
Contributor

@oleksost oleksost commented Dec 2, 2025

TP support for reverse KL loss.

  • adds support for vocabulary parallel reverse KL loss calculation using torch (no fused implementataion).
  • Sequence parallel loss calculation is not supported to keep it simple (I don't think we use sequence parallel embeddings/head)
  • this also fixes a small bug in CE loss for when it is used for distillation

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

  • added _torch_reverse_kl_forward_backward in cross_entropy.py
  • added test_rkl_loss

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

📊 Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


🗒️ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

@oleksost oleksost marked this pull request as ready for review December 2, 2025 19:52
Comment thread fast_llm/functional/cross_entropy.py Outdated
Comment thread fast_llm/functional/config.py Outdated
Comment thread fast_llm/models/gpt/conversion/llama.py
Comment thread tests/test_distillation_loss.py Outdated
# then we average: 1/K sum_ranks (log Z - sum_i t_i * z_i)
# = log Z - 1/K sum_ranks (sum_i t_i * z_i)
# but sum_ranks (sum_i t_i * z_i) = sum_i t_i * z_i (over all vocab)
predicted_logits = predicted_logits * group.size()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks wrong, see previous comment. The previous version was tested and confirmed to work.

Copy link
Copy Markdown
Contributor Author

@oleksost oleksost Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was ist also tested with soft labels (i.ew. when targets are logits)? Without this scaling this new test does not pass.

The reason is that when here we average loss over ranks, we basically do 1/K sum_K (log (Z) - sum_i z_i t_i), where sum_i z_i t_i is local predicted_logits and K is number of ranks. Then what we we get is 1/K * K log (Z) - 1/K predicted_logits_global, so 1/K that scales global predicted_logits does mot cancel out without scaling it by K before.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I didn't realize this was for distillation only. This one is less robustly tested so errors are possible. But if I understand correctly we just need to replace the mean reduction below with a sum reduction on predicted_logits only?

Copy link
Copy Markdown
Contributor Author

@oleksost oleksost Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeh, either of two

  • scale predicted_logits by group size and keep everything as is (i.e. still AVG reduction on loss)
  • or do SUM reduction on predicted_logits instead of AVG reduction on loss below

@oleksost oleksost requested a review from jlamypoirier December 3, 2025 14:20
Copy link
Copy Markdown
Collaborator

@jlamypoirier jlamypoirier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, but some suggestions on improving the tests

@@ -0,0 +1,185 @@
import os
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move to tests/functional

Also consider renaming to test_cross_entropy (to match the implementation file) and moving test_cross_entropy here.

torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6)


def _ce_vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to match the implementation and parametrization from test_cross_entropy

@oleksost oleksost merged commit cc90338 into main Dec 4, 2025
3 of 4 checks passed
@oleksost oleksost deleted the rev_kl_tp branch December 4, 2025 23:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants