diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index bcb58e8d40b9..71180504536d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -30,6 +30,7 @@ import tempfile import time import warnings +from collections import defaultdict from collections.abc import Callable, Iterator, Mapping from functools import partial from pathlib import Path @@ -435,6 +436,8 @@ def __init__( self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) self._memory_tracker.start() + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + # set the correct log level depending on the node log_level = args.get_process_log_level() logging.set_verbosity(log_level) @@ -3567,6 +3570,14 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: start_time (`Optional[float]`): The start of training. """ + mode = "train" if self.model.training else "eval" + if self._metrics[mode]: + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + logs = {**logs, **metrics} + self._metrics[mode].clear() + if self.state.epoch is not None: logs["epoch"] = self.state.epoch if self.args.include_num_input_tokens_seen != "no":