85 changes: 83 additions & 2 deletions src/transformers/trainer.py
@@ -1496,6 +1496,8 @@ def _inner_training_loop(
        self._tr_loss = torch.tensor(0.0, device=args.device)
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        # Sums detached scalar auxiliary losses between logging steps (same schedule as `_tr_loss`).
        self._aux_losses_accumulator: dict[str, torch.Tensor] = {}

        model.zero_grad()

@@ -1906,7 +1908,10 @@ def training_step(
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch, return_outputs=True)
            outputs = None
            if isinstance(loss, tuple):
                loss, outputs = loss

        del inputs
        if (
@@ -1929,6 +1934,9 @@
            # If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
            loss = loss / self.current_gradient_accumulation_steps

        self._accumulate_auxiliary_losses(outputs, num_items_in_batch=num_items_in_batch)
        del outputs

        # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
        # https://github.com/huggingface/transformers/pull/35808
        if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
@@ -1938,6 +1946,74 @@

        return loss.detach()

    def _extract_auxiliary_losses_for_logging(self, outputs: Any) -> dict[str, torch.Tensor]:
        """
        Collect scalar per-term losses for logging alongside the main `loss`.

        Supports:
        - `loss_dict` (`dict` of scalar tensors), as used by several vision/detection heads.
        - Top-level output fields whose name contains ``loss`` (apart from the main ``loss``), when they are scalar tensors.
        """
        collected: dict[str, torch.Tensor] = {}
        if outputs is None:
            return collected
        try:
            items = outputs.items()
        except (AttributeError, TypeError):
            return collected

        skip_top_level = {
            "loss",
            "loss_dict",
            "logits",
            "start_logits",
            "end_logits",
            "mems",
            "hidden_states",
            "attentions",
            "cross_attentions",
            "encoder_last_hidden_state",
            "decoder_hidden_states",
            "decoder_attentions",
            "encoder_attentions",
            "past_key_values",
            "cache_params",
        }
        for key, value in items:
            if key == "loss_dict" and isinstance(value, dict):
                for name, tensor in value.items():
                    if isinstance(tensor, torch.Tensor) and tensor.numel() == 1:
                        safe = str(name).replace("/", "_")
                        collected[f"loss_dict_{safe}"] = tensor
                continue
            if key in skip_top_level:
                continue
            if isinstance(value, torch.Tensor) and value.numel() == 1 and key != "loss" and "loss" in key.lower():
                collected[key] = value
        return collected

    def _accumulate_auxiliary_losses(
        self,
        outputs: Any,
        num_items_in_batch: torch.Tensor | int | None,
    ) -> None:
        extras = self._extract_auxiliary_losses_for_logging(outputs)
        if not extras:
            return
        if self.args.n_gpu > 1:
            for k in list(extras.keys()):
                extras[k] = extras[k].mean()
        if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
            gas = self.current_gradient_accumulation_steps
            for k in list(extras.keys()):
                extras[k] = extras[k] / gas
        device = self.args.device
        for k, v in extras.items():
            v = v.detach()
            if k not in self._aux_losses_accumulator:
                self._aux_losses_accumulator[k] = torch.zeros((), device=device, dtype=v.dtype)
            self._aux_losses_accumulator[k] = self._aux_losses_accumulator[k] + v.to(device)

    def compute_loss(
        self,
        model: nn.Module,
@@ -2065,7 +2141,12 @@ def _maybe_log_save_evaluate(
            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
            steps_since_log = self.state.global_step - self._globalstep_last_logged
            logs["loss"] = tr_loss_scalar / steps_since_log
            for aux_name, aux_tensor in list(self._aux_losses_accumulator.items()):
                aux_scalar = nested_gather(aux_tensor, self.args.parallel_mode).mean().item()
                logs[aux_name] = aux_scalar / steps_since_log
            self._aux_losses_accumulator.clear()
            if grad_norm is not None:
                logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
            if learning_rate is not None:
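For reference, a minimal standalone sketch of the extraction rule the new `_extract_auxiliary_losses_for_logging` helper implements. It is simplified (it drops the explicit skip list the real method uses for fields such as `logits` or `hidden_states`), and the sample output dict and the `aux_lm_loss` field are made up for illustration:

import torch


def extract_aux_losses(outputs: dict) -> dict:
    # Mirrors the rule above: collect scalar tensors from `loss_dict`, plus any other
    # top-level scalar field whose name contains "loss" (the main `loss` itself is skipped).
    collected = {}
    for key, value in outputs.items():
        if key == "loss_dict" and isinstance(value, dict):
            for name, tensor in value.items():
                if isinstance(tensor, torch.Tensor) and tensor.numel() == 1:
                    collected[f"loss_dict_{str(name).replace('/', '_')}"] = tensor
        elif key != "loss" and "loss" in key.lower() and isinstance(value, torch.Tensor) and value.numel() == 1:
            collected[key] = value
    return collected


# Hypothetical detection-style output: main loss, per-term loss_dict, one extra scalar loss field.
sample = {
    "loss": torch.tensor(3.0),
    "loss_dict": {"loss_ce": torch.tensor(1.0), "loss_bbox": torch.tensor(2.0)},
    "aux_lm_loss": torch.tensor(0.5),
    "logits": torch.randn(2, 4),
}
print(extract_aux_losses(sample))
# {'loss_dict_loss_ce': tensor(1.), 'loss_dict_loss_bbox': tensor(2.), 'aux_lm_loss': tensor(0.5000)}

These keys are what ends up in the training logs next to `loss`, averaged over the global steps since the last log per the `_maybe_log_save_evaluate` change above.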
30 changes: 30 additions & 0 deletions tests/trainer/test_trainer.py
@@ -78,6 +78,9 @@
BasicTextGenerationModel,
RegressionDataset,
RegressionModel,
RegressionModelConfig,
RegressionPreTrainedModelWithLossDict,
RegressionTrainingArguments,
RepeatDataset,
StoreLossCallback,
TrainerIntegrationCommon,
@@ -1110,6 +1113,33 @@ def test_use_liger_kernel_custom_config_trainer(self):
class TrainerIntegrationTest(TestCasePlus):
    """Integration tests: compatibility, and e2e."""

    def test_trainer_logs_auxiliary_losses_from_loss_dict(self):
        """When the model returns `loss_dict`, Trainer should log each term (see #31081)."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            config = RegressionModelConfig(a=0.0, b=0.0)
            model = RegressionPreTrainedModelWithLossDict(config)
            train_dataset = RegressionDataset(length=16)
            args = RegressionTrainingArguments(
                output_dir=tmp_dir,
                max_steps=2,
                logging_steps=1,
                per_device_train_batch_size=4,
                disable_tqdm=True,
            )
            trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
            trainer.model_accepts_loss_kwargs = False
            trainer.train()
            for row in trainer.state.log_history:
                if "loss" not in row:
                    continue
                self.assertIn("loss_dict_mse", row)
                self.assertIn("loss_dict_l1", row)
                self.assertIsInstance(row["loss_dict_mse"], float)
                self.assertIsInstance(row["loss_dict_l1"], float)
                break
            else:
                self.fail("No training log row with auxiliary losses found")

    @slow
    @run_first
    @require_non_hpu
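One note on the test's `trainer.model_accepts_loss_kwargs = False`: it forces the branch where `_accumulate_auxiliary_losses` divides each term by the gradient-accumulation steps, mirroring how the main loss is normalized. A worked example with illustrative numbers (not taken from the test) of how that scaling reaches the logged value:

# Illustrative numbers only: each micro-batch reports an auxiliary term of 0.8,
# gradient_accumulation_steps is 4, and logging happens every global step.
gradient_accumulation_steps = 4
micro_batch_terms = [0.8, 0.8, 0.8, 0.8]

# Each training_step adds term / gradient_accumulation_steps to the accumulator ...
accumulated = sum(t / gradient_accumulation_steps for t in micro_batch_terms)

# ... and _maybe_log_save_evaluate divides by the global steps since the last log (1 here),
# so the logged value is the per-step average of the term.
steps_since_log = 1
print(accumulated / steps_since_log)  # 0.8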
24 changes: 24 additions & 0 deletions tests/trainer/trainer_test_utils.py
@@ -314,6 +314,30 @@ def forward(self, input_x, labels=None, **kwargs):
        loss = nn.functional.mse_loss(y, labels)
        return (loss, y, y) if self.double_output else (loss, y)

class RegressionPreTrainedModelWithLossDict(PreTrainedModel):
    """Like `RegressionPreTrainedModel` but returns a combined loss plus a `loss_dict` for logging tests."""

    config_class = RegressionModelConfig
    base_model_prefix = "regression"

    def __init__(self, config):
        super().__init__(config)
        self.a = nn.Parameter(torch.as_tensor(config.a).float())
        self.b = nn.Parameter(torch.as_tensor(config.b).float())
        self.post_init()

    def forward(self, input_x, labels=None, **kwargs):
        y = input_x * self.a + self.b
        if labels is None:
            return (y,)
        mse = nn.functional.mse_loss(y, labels)
        l1 = nn.functional.l1_loss(y, labels)
        combined = mse + 0.25 * l1
        return {
            "loss": combined,
            "loss_dict": {"mse": mse, "l1": l1},
        }

class RegressionDictModel(nn.Module):
    def __init__(self, a=0, b=0):
        super().__init__()
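A quick sanity check of the new helper model's output format, assuming the test utilities are importable as below (the import path is a guess; adjust to however `trainer_test_utils` is loaded in your setup):

import torch

from tests.trainer.trainer_test_utils import (  # assumed import path
    RegressionModelConfig,
    RegressionPreTrainedModelWithLossDict,
)

model = RegressionPreTrainedModelWithLossDict(RegressionModelConfig(a=0.0, b=0.0))
out = model(input_x=torch.randn(8), labels=torch.randn(8))
print(out["loss"])               # combined term: mse + 0.25 * l1
print(sorted(out["loss_dict"]))  # ['l1', 'mse'] -> logged by Trainer as loss_dict_l1 / loss_dict_mse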