Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion fast_llm/engine/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,6 @@ def _validate(self) -> None:
# TODO: Add support.
Assert.eq(self.model.distributed.pipeline_parallel, 1)
# TODO: Check if these work.
Assert.eq(self.model.distributed.tensor_parallel, 1)
Assert.eq(self.model.distributed.sequence_data_parallel, 1)
if self.run.experiment_dir is None:
assert not self.training.checkpoint.enabled()
Expand Down
149 changes: 78 additions & 71 deletions fast_llm/functional/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,23 @@ def _fused_cross_entropy_forward_backward(
predicted_logits = logits_norm.gather(1, target)
if group is not None:
predicted_logits = target_mask * predicted_logits

all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
else:
predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True)
if group is not None and target_format != TargetFormat.labels:
# this is needed because on each rank we calculate log Z - sum_i t_i * z_i, where z_i is logit.
# Then we average on line 160: 1/K sum_ranks (log Z - sum_i t_i * z_i)
# = log Z - 1/K sum_ranks (sum_i t_i * z_i), which is the global predicted_logits, so without multiplying it by K the 1/K does not cancel out.
predicted_logits = predicted_logits * group.size()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks wrong, see previous comment. The previous version was tested and confirmed to work.

Copy link
Copy Markdown
Contributor Author

@oleksost oleksost Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was it also tested with soft labels (i.e. when targets are logits)? Without this scaling this new test does not pass.

The reason is that when we average the loss over ranks here, we basically do 1/K sum_K (log (Z) - sum_i z_i t_i), where sum_i z_i t_i is the local predicted_logits and K is the number of ranks. Then what we get is 1/K * K log (Z) - 1/K predicted_logits_global, so the 1/K that scales the global predicted_logits does not cancel out without scaling it by K before.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I didn't realize this was for distillation only. This one is less robustly tested so errors are possible. But if I understand correctly we just need to replace the mean reduction below with a sum reduction on predicted_logits only?

Copy link
Copy Markdown
Contributor Author

@oleksost oleksost Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeh, either of two

  • scale predicted_logits by group size and keep everything as is (i.e. still AVG reduction on loss)
  • or do SUM reduction on predicted_logits instead of AVG reduction on loss below


per_sample_loss = sum_exp_logits.log() - predicted_logits
if loss_mask is not None:
per_sample_loss = per_sample_loss * loss_mask

loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
all_reduce(loss, op=ReduceOp.AVG, group=group)

return loss, grad

Expand Down Expand Up @@ -213,71 +219,72 @@ def cross_entropy_forward_backward(
)


def distributed_log_softmax(
    logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1
):
    """Log-softmax of ``logits`` along ``dim``, distributed over ``group`` when one is given.

    Delegates the heavy lifting to ``_fused_softmax_base``, which (presumably — confirm
    against its definition) returns the shifted/normalized logits and the sum of their
    exponentials, reduced across ``group`` when the vocab is sharded. The log-softmax is
    then recovered as ``normalized_logits - log(sum_exp)``.
    """
    normalized_logits, _, exp_sum = _fused_softmax_base(logits, logits_scale_factor, group=group, dim=dim)
    log_normalizer = exp_sum.log()
    return normalized_logits - log_normalizer


def _torch_reverse_kl_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor | None,
grad_output: float | None,
logits_scale_factor: float,
target_format: TargetFormat,
group: ProcessGroup | None = None,
logits_scale_factor: float = 1.0,
teacher_softmax_temperature: float = 1.0,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Reverse KL using PyTorch's native kl_div function.
Much simpler and more reliable than custom implementation!
This is used for the TP version where we split across the vocab dimension. KL is additive over partitions of the vocab.

Takes:
logits: [BxS, V] or [B, S, V]
target: [BxS, V] or [B, S, V] (logits format)
loss_mask: [BxS] or [B, S] or None
...
"""
Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
Assert.eq(
teacher_softmax_temperature,
1,
msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL",
)
Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL")
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])

# Compute log probabilities - let _fused_softmax handle scaling internally
# teacher_probs = _fused_softmax(target, logits_scale_factor * (1 / teacher_softmax_temperature), group)
# # teacher_log_probs = torch.log(teacher_probs + 1e-8) # log(p)
# teacher_probs = torch.clamp(teacher_probs, min=1e-7) # or even 1e-6
# teacher_log_probs = torch.log(teacher_probs)

# Scale target logits more carefully
scaled_target = target * (logits_scale_factor / teacher_softmax_temperature)

# Clamp to prevent extreme values before log_softmax
scaled_target = torch.clamp(scaled_target, min=-50, max=50)
teacher_log_probs = torch.log_softmax(scaled_target, dim=-1)

# For reverse KL: KL(q||p) = Ξ£ q * log(q/p) = Ξ£ q * (log(q) - log(p))
# Use kl_div with: input=log(p), target=q, log_target=False
# This gives: Ξ£ q * (log(q) - log(p)) = exactly what we want!

# Compute log probabilities
teacher_log_probs = distributed_log_softmax(target.float(), group=group)
# batch_size = logits.shape[0]
with torch.enable_grad():
logits_ = logits.detach().requires_grad_(grad_output is not None)

# Use log_softmax for consistency instead of _fused_softmax
scaled_logits = logits_ * logits_scale_factor
scaled_logits = torch.clamp(scaled_logits, min=-50, max=50)
student_log_probs = torch.log_softmax(scaled_logits, dim=-1)

# Convert to probabilities for kl_div
# student_probs_ = torch.exp(student_log_probs)
logits_ = logits.float().detach().requires_grad_(grad_output is not None)
student_log_probs = distributed_log_softmax(logits_, group=group)

# Reverse KL: input=teacher_log_probs, target=student_probs
if loss_mask is None:
loss = torch.nn.functional.kl_div(
teacher_log_probs, # input = log(p)
student_log_probs, # target = log(q)
reduction="batchmean",
log_target=True,
)
loss_terms = torch.nn.functional.kl_div(
teacher_log_probs, # input = log(p)
student_log_probs, # target = log(q)
reduction="none",
log_target=True,
).sum(dim=-1)
if loss_mask is not None:
# loss mask is the same on all ranks for TP over vocab.
valid = loss_mask.to(loss_terms.dtype)
loss_terms = loss_terms * valid
valid_tokens = torch.tensor(valid.sum(), device=loss_terms.device, dtype=loss_terms.dtype)
else:
# Apply loss mask - this requires some reshaping
loss_per_sample = torch.nn.functional.kl_div(
teacher_log_probs, student_log_probs, reduction="none", log_target=True
).sum(dim=-1)
loss = (loss_per_sample * loss_mask).mean()
valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
loss = loss_terms.sum() # sums over batch and seq. len.

if group is not None and target_format != TargetFormat.labels:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
if group is not None:
all_reduce(loss, op=ReduceOp.SUM, group=group)
loss /= valid_tokens

if grad_output is not None:
loss.backward(torch.full_like(loss, grad_output))
Expand All @@ -297,6 +304,7 @@ def reverse_kl_forward_backward(
logits_scale_factor: float = 1.0,
teacher_softmax_temperature: float = 1.0,
target_format: TargetFormat = TargetFormat.labels,
sequence_parallel_logits: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher).
Expand All @@ -309,37 +317,36 @@ def reverse_kl_forward_backward(
- Standard CE: KL(p||q) = mode-covering (spreads mass broadly)
- Reverse KL: KL(q||p) = mode-seeking (focuses on target modes)

Args:
logits: Model predictions [batch_size, ..., vocab_size]
target: Target distribution or labels
loss_mask: Optional mask for loss computation
grad_output: Gradient output scale factor
group: Process group for tensor parallelism
logits_scale_factor: Temperature scaling factor (1/T)
target_format: Format of target (labels or logits)
Takes:
logits: [BxS, V] or [B, S, V], where V is local vocab size
target: [BxS, V] or [B, S, V] (logits format)
loss_mask: [BxS] or [B, S] or None
...

Returns:
loss: Reverse KL divergence loss
grad: Gradients w.r.t. logits

Example usage:
# Replace standard cross-entropy with reverse KL
# loss, grad = cross_entropy_forward_backward(logits, target, ...)
loss, grad = reverse_kl_forward_backward(logits, target,
loss_mask=None,
grad_output=1.0,
logits_scale_factor=1.0,
target_format=TargetFormat.labels)
"""
if target_format == TargetFormat.labels:
Assert.eq(target.shape, logits.shape[:-1])
Assert.eq(target.dtype, torch.int64)
else:
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])

if sequence_parallel_logits:
# TODO: see hybrid dev branch where it is implemented
raise NotImplementedError("Sequence-parallel reverse KL is not implemented yet, set vocab_parallel true")

Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])

# TODO: implement fused?
return _torch_reverse_kl_forward_backward(
logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group, teacher_softmax_temperature
distillation_loss, distillation_grad = _torch_reverse_kl_forward_backward(
logits=logits,
target=target,
loss_mask=loss_mask,
grad_output=grad_output,
logits_scale_factor=logits_scale_factor,
target_format=target_format,
teacher_softmax_temperature=teacher_softmax_temperature,
group=group,
)
return distillation_loss, distillation_grad
2 changes: 2 additions & 0 deletions fast_llm/layers/language_model/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,9 @@ def _logits_cross_entropy_forward_backward(
target_format=(
TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits
),
sequence_parallel_logits=self._sequence_parallel_logits,
)

elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy:
distillation_loss, distillation_grad = cross_entropy_forward_backward(
logits.flatten(0, -2),
Expand Down
16 changes: 8 additions & 8 deletions fast_llm/models/gpt/conversion/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,19 @@ def import_config(cls, config: dict) -> dict:
elif rope_type == "llama3":
rotary_config.update(
{
"scale_factor": config["factor"],
"low_frequency_factor": config["low_freq_factor"],
"high_frequency_factor": config["high_freq_factor"],
"original_context_length": config["original_max_position_embeddings"],
"scale_factor": config["rope_scaling"]["factor"],
"low_frequency_factor": config["rope_scaling"]["low_freq_factor"],
"high_frequency_factor": config["rope_scaling"]["high_freq_factor"],
"original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
}
)
elif rope_type == "yarn":
rotary_config.update(
{
"attention_factor": config["attention_factor"],
"beta_fast": config["beta_fast"],
"beta_slow": config["beta_slow"],
"original_context_length": config["original_max_position_embeddings"],
"attention_factor": config["rope_scaling"]["attention_factor"],
Comment thread
jlamypoirier marked this conversation as resolved.
"beta_fast": config["rope_scaling"]["beta_fast"],
"beta_slow": config["rope_scaling"]["beta_slow"],
"original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
}
)
else:
Expand Down
Loading
Loading