
[Don't review] ChunkedCELoss #2937

Draft
wwwjn wants to merge 4 commits into main from chunked-loss

Conversation

@wwwjn (Contributor) commented Apr 10, 2026

Not ready for review

wwwjn added 4 commits April 10, 2026 15:40
Implements chunked cross-entropy loss that splits the sequence dimension
into N chunks, computing lm_head projection and CE loss per-chunk to avoid
materializing the full [B, L, V] logits tensor at once.

Key components:
- ChunkedCELoss: wraps lm_head + ce_loss with chunked forward/backward
- GradAccumulator: pre-allocated buffer for assembling chunk gradients
- _no_reshard_after_backward: FSDP2 context to avoid N all-gathers
- skip_lm_head kwarg on Decoder.forward() for the detach boundary
- ChunkedCELossFactory: deferred initialization (model not available at build time)
- Trainer integration with dedicated forward_backward_step branch
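The per-chunk pattern the commit message describes can be sketched roughly as follows. This is a minimal, FSDP-free illustration; `chunked_ce_loss` and its signature are illustrative, not the PR's actual API:

```python
import torch
import torch.nn.functional as F

def chunked_ce_loss(hidden, labels, lm_weight, num_chunks=4):
    """Split the sequence dim into num_chunks, computing the lm_head
    projection and CE loss per chunk so the full [B, L, V] logits tensor
    is never materialized at once."""
    B, L, _ = hidden.shape
    total = hidden.new_zeros(())
    for h_chunk, y_chunk in zip(
        hidden.chunk(num_chunks, dim=1), labels.chunk(num_chunks, dim=1)
    ):
        # Only a [B, L/num_chunks, V] logits slice is live at any time.
        logits = F.linear(h_chunk, lm_weight)
        total = total + F.cross_entropy(
            logits.flatten(0, 1), y_chunk.flatten(), reduction="sum"
        )
    return total / (B * L)  # mean over all tokens, matching unchunked CE
```

Summing per-chunk losses with `reduction="sum"` and dividing by the total token count keeps the result identical to computing cross-entropy over the full logits tensor.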
…CELoss

- Add loss_num_chunks to TrainingConfig (default 1, no-op)
- Trainer auto-wraps loss_fn in ChunkedCELossFactory when loss_num_chunks > 1
- Integration tests for FSDP, FSDP+TP(SP), FSDP+CP, FSDP+TP+CP, FSDP+compile
FSDP2's backward hooks are one-shot per forward pass. The previous approach
of calling self.lm_head(h_chunk) triggered FSDP2's backward hooks during
chunk backward, leaving no hooks for the decoder backward (h.backward(grad)),
causing zero gradients on model parameters.

Fix: Use F.linear(h_chunk, lm_weight) to bypass FSDP2 module hooks during
chunk computation. Use (h * accumulated_grad).sum().backward() instead of
h.backward(grad) to properly trigger FSDP2's hooks in a single backward pass.
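The two-phase fix above can be sketched without FSDP2 as follows (function name and signature are illustrative): chunk backwards run against a detached copy of the hidden states, and a single `(h * grad).sum().backward()` then carries the accumulated gradient through whatever produced `h`:

```python
import torch
import torch.nn.functional as F

def forward_backward(h, labels, lm_weight, num_chunks=2):
    """Two-phase chunked loss: returns the scalar loss value and leaves
    gradients on the producers of h (the decoder, in the PR's terms)."""
    B, L, _ = h.shape
    h_det = h.detach().requires_grad_(True)
    total_loss = 0.0
    # Phase 1: per-chunk projection + CE backward against the detached
    # copy. F.linear avoids module forward/backward hooks, so an
    # FSDP2-wrapped lm_head's one-shot hooks would not be consumed here.
    for h_chunk, y_chunk in zip(
        h_det.chunk(num_chunks, dim=1), labels.chunk(num_chunks, dim=1)
    ):
        logits = F.linear(h_chunk, lm_weight)
        chunk_loss = F.cross_entropy(
            logits.flatten(0, 1), y_chunk.flatten(), reduction="sum"
        ) / (B * L)
        chunk_loss.backward()  # accumulates into h_det.grad
        total_loss += chunk_loss.item()
    # Phase 2: one backward through the producer of h carrying the
    # accumulated gradient -- the (h * grad).sum().backward() form that
    # fires the one-shot backward hooks exactly once.
    (h * h_det.grad.detach()).sum().backward()
    return total_loss
```

Because `dL/dh` is fully assembled before phase 2, the single backward pass produces the same parameter gradients as an unchunked `loss.backward()`.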
Replace bare function + build_fn pattern with proper loss classes.
CrossEntropyLoss and MSELoss encapsulate compilation logic internally.
The old function names (cross_entropy_loss, mse_loss) remain as public API
for backward compatibility. build_cross_entropy_loss and build_mse_loss
now return class instances.
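The class pattern this commit describes might look roughly like the sketch below. The class and builder names come from the commit message; the constructor signature and internals are assumptions:

```python
import torch
import torch.nn.functional as F

class CrossEntropyLoss:
    """Callable loss object that owns its compilation state, replacing the
    bare-function + build_fn pattern."""

    def __init__(self, compile: bool = False):
        # Compilation is encapsulated here instead of at the call site.
        self._fn = torch.compile(self._loss) if compile else self._loss

    @staticmethod
    def _loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten())

    def __call__(self, pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        return self._fn(pred, labels)

def build_cross_entropy_loss(compile: bool = False) -> CrossEntropyLoss:
    # Per the commit message, the builder now returns a class instance.
    return CrossEntropyLoss(compile=compile)
```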
meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Apr 10, 2026
@wwwjn wwwjn changed the title ChunkedCELoss [Don't review] ChunkedCELoss Apr 10, 2026
"chunked_loss_fsdp+tp+cp",
ngpu=8,
),
OverrideDefinitions(
wwwjn (author) commented:
Need to consolidate these into a single compound test.



def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
class CrossEntropyLoss:
wwwjn (author) commented:
Need a cleaner refactor for the loss part.

"""Initialize the gradient accumulator.

Args:
reference: Reference tensor to get shape, device, and dtype from.
wwwjn (author) commented:
Pass the shape directly.
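The accumulator the quoted docstring describes might look like the sketch below, initialized from a reference tensor for shape and device as the docstring says (the review comment suggests passing the shape directly instead). `result()` appears elsewhere in the diff; `add` and the fixed float32 buffer are assumptions based on the PR's `accumulated_grad.dtype == torch.float32` assertion:

```python
import torch

class GradAccumulator:
    """Pre-allocated buffer for assembling per-chunk gradients."""

    def __init__(self, reference: torch.Tensor):
        # Allocate once, in float32, so per-chunk adds don't reallocate.
        self._buf = torch.zeros(
            reference.shape, device=reference.device, dtype=torch.float32
        )

    def add(self, start: int, end: int, grad: torch.Tensor) -> None:
        # Accumulate one chunk's gradient into its sequence-dim slice.
        self._buf[:, start:end] += grad.to(torch.float32)

    def result(self) -> torch.Tensor:
        return self._buf
```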

h_chunk.grad = None
del scaled_chunk_loss, chunk_loss, logits

# Get the accumulated gradient and backward through the decoder.
wwwjn (author) commented:
Revisit this FSDP trigger bug

# Use F.linear instead of self.lm_head(h_chunk) to bypass FSDP2's
# module forward/backward hooks. This ensures FSDP2's one-shot
# backward hooks remain available for the decoder backward below.
logits = torch.nn.functional.linear(h_chunk, lm_weight)
wwwjn (author) commented:
Refactor: avoid using F.linear here, but still trigger the FSDP hooks at the correct time.

accumulated_grad = grad_accumulator.result()
assert accumulated_grad.dtype == torch.float32

decoder_loss = (hidden_states * accumulated_grad.to(hidden_states.dtype)).sum()
wwwjn (author) commented:
This is too tricky; find a way to trigger FSDP's hooks explicitly.

"""Factory for creating ChunkedCELoss after model construction.

Since ChunkedCELoss needs the model's lm_head, and the model is not available
at loss builder time, this factory is returned by build_chunked_cross_entropy_loss
wwwjn (author) commented:
Remove the build function
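The deferred-initialization idea from the quoted docstring can be sketched as follows: the factory is returned at loss-builder time, and only materializes the loss once the constructed model (and its lm_head) exists. The `ChunkedCELoss` stand-in here is a minimal placeholder, not the PR's implementation, and the factory's `__call__(model)` signature is an assumption:

```python
import torch
import torch.nn.functional as F

class ChunkedCELoss:
    """Minimal stand-in: per-chunk lm_head projection + CE loss."""

    def __init__(self, lm_head: torch.nn.Linear, num_chunks: int):
        self.lm_head = lm_head
        self.num_chunks = num_chunks

    def __call__(self, hidden: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        B, L, _ = hidden.shape
        loss = hidden.new_zeros(())
        for h_c, y_c in zip(
            hidden.chunk(self.num_chunks, 1), labels.chunk(self.num_chunks, 1)
        ):
            logits = F.linear(h_c, self.lm_head.weight)
            loss = loss + F.cross_entropy(
                logits.flatten(0, 1), y_c.flatten(), reduction="sum"
            )
        return loss / (B * L)

class ChunkedCELossFactory:
    """Deferred initialization: the model isn't available at build time,
    so the trainer invokes the factory after model construction."""

    def __init__(self, num_chunks: int):
        self.num_chunks = num_chunks

    def __call__(self, model: torch.nn.Module) -> ChunkedCELoss:
        return ChunkedCELoss(model.lm_head, self.num_chunks)
```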


Labels: ciflow/8gpu, CLA Signed