cp: fix: Fix gradient clipping of non-float32 params (1158) into r0.4.0 #1258
chtruong814 wants to merge 1 commit into r0.4.0 from
Signed-off-by: Jarno Seppänen <jseppanen@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
Walkthrough
Removed dtype handling from the gradient clipping utility and updated the policy workers to match the new signature. clip_grad_by_total_norm_ now collects gradients only when clipping is needed (clip_coeff < 1.0) and no longer casts them to another dtype. The policy workers drop the dtype argument from the clipping calls in train(); the rest of the training flow is unchanged.
Sequence Diagram(s)
sequenceDiagram
participant Trainer
participant PolicyWorker
participant GradUtils as Grad Utils
Trainer->>PolicyWorker: train()
PolicyWorker->>GradUtils: get_grad_norm(params, dtype=float32)
GradUtils-->>PolicyWorker: total_norm
PolicyWorker->>GradUtils: clip_grad_by_total_norm_(params, max_norm, total_norm)
note right of GradUtils: New: lazily collect grads only if clip_coeff < 1.0<br/>No dtype casting performed
GradUtils-->>PolicyWorker: gradients possibly clipped
PolicyWorker-->>Trainer: step complete
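The flow in the walkthrough and diagram can be sketched in PyTorch roughly as follows. This is a minimal illustration, not the code from the PR: the function names and the `(params, max_norm, total_norm)` argument order come from the diagram, while the function bodies, the `eps` parameter, and the toy parameters are assumptions made for the example. The norm is accumulated in float32, and the clipping step scales each gradient in place, so a bfloat16 gradient stays bfloat16.

```python
import torch


def get_grad_norm(params, dtype=torch.float32):
    """Total L2 norm over all gradients, accumulated in `dtype` (sketch)."""
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        return 0.0
    # Upcast only the temporary copies used for the norm, not the grads themselves.
    per_tensor = torch.stack(
        [torch.linalg.vector_norm(g.detach().to(dtype)) for g in grads]
    )
    return torch.linalg.vector_norm(per_tensor).item()


def clip_grad_by_total_norm_(params, max_norm, total_norm, eps=1e-6):
    """Scale gradients in place so their total norm is at most max_norm (sketch).

    `eps` is an assumption of this example, added to avoid division by zero.
    """
    clip_coeff = max_norm / (total_norm + eps)
    if clip_coeff < 1.0:
        # Gradients are collected lazily, only when clipping is actually needed.
        grads = [p.grad for p in params if p.grad is not None]
        for g in grads:
            g.mul_(clip_coeff)  # in-place scaling keeps each gradient's dtype


# Hypothetical mixed-precision parameters with hand-set gradients.
w_bf16 = torch.nn.Parameter(torch.ones(4, dtype=torch.bfloat16))
w_fp32 = torch.nn.Parameter(torch.ones(4, dtype=torch.float32))
w_bf16.grad = torch.full((4,), 3.0, dtype=torch.bfloat16)
w_fp32.grad = torch.full((4,), 4.0, dtype=torch.float32)
params = [w_bf16, w_fp32]

total_norm = get_grad_norm(params)  # sqrt(4 * 3^2 + 4 * 4^2), about 10
clip_grad_by_total_norm_(params, max_norm=1.0, total_norm=total_norm)
```

After the call, both gradients are scaled by roughly 0.1 and each keeps its original dtype; the bfloat16 gradient is only approximately scaled because of bfloat16 rounding.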
Estimated code review effort: 2 (Simple), ~10 minutes