Draft
Conversation
Implements chunked cross-entropy loss that splits the sequence dimension into N chunks, computing lm_head projection and CE loss per-chunk to avoid materializing the full [B, L, V] logits tensor at once. Key components: - ChunkedCELoss: wraps lm_head + ce_loss with chunked forward/backward - GradAccumulator: pre-allocated buffer for assembling chunk gradients - _no_reshard_after_backward: FSDP2 context to avoid N all-gathers - skip_lm_head kwarg on Decoder.forward() for the detach boundary - ChunkedCELossFactory: deferred initialization (model not available at build time) - Trainer integration with dedicated forward_backward_step branch
…CELoss - Add loss_num_chunks to TrainingConfig (default 1, no-op) - Trainer auto-wraps loss_fn in ChunkedCELossFactory when loss_num_chunks > 1 - Integration tests for FSDP, FSDP+TP(SP), FSDP+CP, FSDP+TP+CP, FSDP+compile
FSDP2's backward hooks are one-shot per forward pass. The previous approach of calling self.lm_head(h_chunk) triggered FSDP2's backward hooks during chunk backward, leaving no hooks for the decoder backward (h.backward(grad)), causing zero gradients on model parameters. Fix: Use F.linear(h_chunk, lm_weight) to bypass FSDP2 module hooks during chunk computation. Use (h * accumulated_grad).sum().backward() instead of h.backward(grad) to properly trigger FSDP2's hooks in a single backward pass.
Replace bare function + build_fn pattern with proper loss classes. CrossEntropyLoss and MSELoss encapsulate compilation logic internally. The old function names (cross_entropy_loss, mse_loss) remain as public API for backward compatibility. build_cross_entropy_loss and build_mse_loss now return class instances.
wwwjn
commented
Apr 11, 2026
| "chunked_loss_fsdp+tp+cp", | ||
| ngpu=8, | ||
| ), | ||
| OverrideDefinitions( |
Contributor
Author
There was a problem hiding this comment.
Need to consolidate to one single compound test
|
|
||
|
|
||
| def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: | ||
| class CrossEntropyLoss: |
Contributor
Author
There was a problem hiding this comment.
Need a cleaner refactor for loss part
| """Initialize the gradient accumulator. | ||
|
|
||
| Args: | ||
| reference: Reference tensor to get shape, device, and dtype from. |
Contributor
Author
There was a problem hiding this comment.
pass the shape direclty
| h_chunk.grad = None | ||
| del scaled_chunk_loss, chunk_loss, logits | ||
|
|
||
| # Get the accumulated gradient and backward through the decoder. |
Contributor
Author
There was a problem hiding this comment.
Revisit this FSDP trigger bug
| # Use F.linear instead of self.lm_head(h_chunk) to bypass FSDP2's | ||
| # module forward/backward hooks. This ensures FSDP2's one-shot | ||
| # backward hooks remain available for the decoder backward below. | ||
| logits = torch.nn.functional.linear(h_chunk, lm_weight) |
Contributor
Author
There was a problem hiding this comment.
Refactor, avoid using functional but trigger FSDP hooks at correct time
| accumulated_grad = grad_accumulator.result() | ||
| assert accumulated_grad.dtype == torch.float32 | ||
|
|
||
| decoder_loss = (hidden_states * accumulated_grad.to(hidden_states.dtype)).sum() |
Contributor
Author
There was a problem hiding this comment.
This too tricky, find how to trigger FSDP hooks explicitly
| """Factory for creating ChunkedCELoss after model construction. | ||
|
|
||
| Since ChunkedCELoss needs the model's lm_head, and the model is not available | ||
| at loss builder time, this factory is returned by build_chunked_cross_entropy_loss |
Contributor
Author
There was a problem hiding this comment.
Remove the build function
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Not ready for review