Use a private _metrics dict to allow for additional metric logging #43599

@qgallouedec

When defining your own trainer, you typically want to log your own metrics. Over time, in TRL we've converged on this structure in all trainers:

from collections import defaultdict
from transformers import Trainer

class MyTrainer(Trainer):
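    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        ...
        # Per-mode buffers: metric name -> values recorded since the last `log` call
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}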

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        mode = "train" if self.model.training else "eval"
        (loss, outputs) = super().compute_loss(
            model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
        )
        ...  # compute `my_value` here
        self._metrics[mode]["my_metric"].append(my_value)
        return (loss, outputs) if return_outputs else loss

    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        mode = "train" if self.model.training else "eval"
        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if mode == "eval":
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs = {**logs, **metrics}
        super().log(logs, start_time)
        self._metrics[mode].clear()

This works well, but it requires every custom trainer to wrap Trainer.log. We may want to move this design upstream, i.e., give Trainer an internal _metrics dictionary that is responsible for storing metrics.
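To make the proposal concrete, here is a minimal sketch of what the upstream version could look like. It does not use the real `transformers.Trainer`: `SketchTrainer` is a hypothetical stand-in, and `self.training` and `self.logged` are placeholders assumed for illustration in place of `self.model.training` and the real logging backend. The point is only that, once the base class owns `_metrics` and averages it in `log`, a subclass no longer needs to override `log`.

```python
from collections import defaultdict


class SketchTrainer:
    """Hypothetical stand-in for `Trainer`; NOT the transformers class."""

    def __init__(self):
        self.training = True  # placeholder for `self.model.training`
        self.logged = []  # placeholder for the real logging backend
        # The buffer the issue proposes to own upstream.
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}

    def log(self, logs, start_time=None):
        mode = "train" if self.training else "eval"
        # Average everything recorded since the last call.
        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}
        # Match the "eval_" prefix used for evaluation keys.
        if mode == "eval":
            metrics = {f"eval_{key}": val for key, val in metrics.items()}
        self.logged.append({**logs, **metrics})
        self._metrics[mode].clear()


class MyTrainer(SketchTrainer):
    # No `log` override: the subclass only records values.
    def compute_loss(self, batch):
        mode = "train" if self.training else "eval"
        loss = sum(batch) / len(batch)  # toy loss, for illustration only
        self._metrics[mode]["my_metric"].append(max(batch))
        return loss


trainer = MyTrainer()
trainer.compute_loss([1.0, 2.0])
trainer.compute_loss([3.0, 4.0])
trainer.log({"loss": 2.5})

trainer.training = False
trainer.compute_loss([2.0])
trainer.log({"eval_loss": 2.0})
```

With this shape, the only contract a custom trainer has to follow is appending to `self._metrics[mode][name]`. Averaging, prefixing, and clearing happen in one place.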
