diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f434d78d4040..f3ca5f2b1152 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1496,6 +1496,8 @@ def _inner_training_loop( self._tr_loss = torch.tensor(0.0, device=args.device) self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step + # Sums detached scalar auxiliary losses between logging steps (same schedule as `_tr_loss`). + self._aux_losses_accumulator: dict[str, torch.Tensor] = {} model.zero_grad() @@ -1906,7 +1908,10 @@ def training_step( return loss_mb.reduce_mean().detach().to(self.args.device) with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) + loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch, return_outputs=True) + outputs = None + if isinstance(loss, tuple): + loss, outputs = loss del inputs if ( @@ -1929,6 +1934,9 @@ def training_step( # If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps loss = loss / self.current_gradient_accumulation_steps + self._accumulate_auxiliary_losses(outputs, num_items_in_batch=num_items_in_batch) + del outputs + # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled # https://github.com/huggingface/transformers/pull/35808 if self.accelerator.distributed_type == DistributedType.DEEPSPEED: @@ -1938,6 +1946,74 @@ def training_step( return loss.detach() + def _extract_auxiliary_losses_for_logging(self, outputs: Any) -> dict[str, torch.Tensor]: + """ + Collect scalar per-term losses for logging alongside the main `loss`. + + Supports: + - `loss_dict` (`dict` of scalar tensors), as used by several vision/detection heads. + - Top-level output fields whose name contains ``loss`` (apart from the main ``loss``), when they are scalar tensors. 
+ """ + collected: dict[str, torch.Tensor] = {} + if outputs is None: + return collected + try: + items = outputs.items() + except (AttributeError, TypeError): + return collected + + skip_top_level = { + "loss", + "loss_dict", + "logits", + "start_logits", + "end_logits", + "mems", + "hidden_states", + "attentions", + "cross_attentions", + "encoder_last_hidden_state", + "decoder_hidden_states", + "decoder_attentions", + "encoder_attentions", + "past_key_values", + "cache_params", + } + for key, value in items: + if key == "loss_dict" and isinstance(value, dict): + for name, tensor in value.items(): + if isinstance(tensor, torch.Tensor) and tensor.numel() == 1: + safe = str(name).replace("/", "_") + collected[f"loss_dict_{safe}"] = tensor + continue + if key in skip_top_level: + continue + if isinstance(value, torch.Tensor) and value.numel() == 1 and key != "loss" and "loss" in key.lower(): + collected[key] = value + return collected + + def _accumulate_auxiliary_losses( + self, + outputs: Any, + num_items_in_batch: torch.Tensor | int | None, + ) -> None: + extras = self._extract_auxiliary_losses_for_logging(outputs) + if not extras: + return + if self.args.n_gpu > 1: + for k in list(extras.keys()): + extras[k] = extras[k].mean() + if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None: + gas = self.current_gradient_accumulation_steps + for k in list(extras.keys()): + extras[k] = extras[k] / gas + device = self.args.device + for k, v in extras.items(): + v = v.detach() + if k not in self._aux_losses_accumulator: + self._aux_losses_accumulator[k] = torch.zeros((), device=device, dtype=v.dtype) + self._aux_losses_accumulator[k] = self._aux_losses_accumulator[k] + v.to(device) + def compute_loss( self, model: nn.Module, @@ -2065,7 +2141,12 @@ def _maybe_log_save_evaluate( # reset tr_loss to zero tr_loss -= tr_loss - logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged) + steps_since_log 
def test_trainer_logs_auxiliary_losses_from_loss_dict(self):
    """When the model returns `loss_dict`, Trainer should log each term (see #31081)."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg = RegressionModelConfig(a=0.0, b=0.0)
        regression_model = RegressionPreTrainedModelWithLossDict(cfg)
        dataset = RegressionDataset(length=16)
        training_args = RegressionTrainingArguments(
            output_dir=tmp_dir,
            max_steps=2,
            logging_steps=1,
            per_device_train_batch_size=4,
            disable_tqdm=True,
        )
        trainer = Trainer(model=regression_model, args=training_args, train_dataset=dataset)
        trainer.model_accepts_loss_kwargs = False
        trainer.train()
        # Inspect the first logged training row; it must carry one entry per loss term.
        logged_row = next((row for row in trainer.state.log_history if "loss" in row), None)
        if logged_row is None:
            self.fail("No training log row with auxiliary losses found")
        self.assertIn("loss_dict_mse", logged_row)
        self.assertIn("loss_dict_l1", logged_row)
        self.assertIsInstance(logged_row["loss_dict_mse"], float)
        self.assertIsInstance(logged_row["loss_dict_l1"], float)

class RegressionPreTrainedModelWithLossDict(PreTrainedModel):
    """Like `RegressionPreTrainedModel` but returns a combined loss plus a `loss_dict` for logging tests."""

    config_class = RegressionModelConfig
    base_model_prefix = "regression"

    def __init__(self, config):
        super().__init__(config)
        self.a = nn.Parameter(torch.as_tensor(config.a).float())
        self.b = nn.Parameter(torch.as_tensor(config.b).float())
        self.post_init()

    def forward(self, input_x, labels=None, **kwargs):
        # Linear regression head: y = a * x + b.
        predictions = input_x * self.a + self.b
        if labels is None:
            return (predictions,)
        mse_term = nn.functional.mse_loss(predictions, labels)
        l1_term = nn.functional.l1_loss(predictions, labels)
        # Combined objective plus the per-term breakdown that Trainer should log.
        return {
            "loss": mse_term + 0.25 * l1_term,
            "loss_dict": {"mse": mse_term, "l1": l1_term},
        }