Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import tempfile
import time
import warnings
from collections import defaultdict
from collections.abc import Callable, Iterator, Mapping
from functools import partial
from pathlib import Path
Expand Down Expand Up @@ -435,6 +436,8 @@ def __init__(
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
self._memory_tracker.start()

self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we can leverage this self._metrics to log the loss

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thought about it, but the loss uses tensor accumulation to avoid TPU sync overhead. `self._metrics` is meant more for auxiliary custom metrics that don't need that optimization.


# set the correct log level depending on the node
log_level = args.get_process_log_level()
logging.set_verbosity(log_level)
Expand Down Expand Up @@ -3567,6 +3570,14 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
start_time (`Optional[float]`):
The start of training.
"""
mode = "train" if self.model.training else "eval"
if self._metrics[mode]:
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}
if mode == "eval":
metrics = {f"eval_{key}": val for key, val in metrics.items()}
logs = {**logs, **metrics}
self._metrics[mode].clear()

if self.state.epoch is not None:
logs["epoch"] = self.state.epoch
if self.args.include_num_input_tokens_seen != "no":
Expand Down