From ffcd459e75d38d71c51b6a42b1e3a487b0d688c0 Mon Sep 17 00:00:00 2001
From: Abigail
Date: Sat, 31 Jan 2026 22:28:10 +0100
Subject: [PATCH] Add _get_signature_columns method to allow custom trainers to override column filtering

---
 src/transformers/trainer.py | 34 ++++++++++++++++++++++------------
 1 file changed, 22 insertions(+), 12 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index bcb58e8d40b9..413009d49ffd 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -925,20 +925,30 @@ def _align_special_tokens(self):
                 f"values. Updated tokens: {updated_tokens}."
             )
 
+    def _get_signature_columns(self) -> list[str]:
+        """
+        Return the parameter names of the model's `forward` signature plus the label column names.
+
+        `_set_signature_columns_if_needed` stores the result in `self._signature_columns`. Custom trainers can
+        override this method to change column filtering.
+        """
+        # Inspect model forward signature to keep only the arguments it accepts.
+        model_to_inspect = self.model
+        if _is_peft_model(self.model):
+            if hasattr(self.model, "get_base_model"):
+                model_to_inspect = self.model.get_base_model()
+            else:
+                # PeftMixedModel do not provide a `get_base_model` method
+                model_to_inspect = self.model.base_model.model
+        signature = inspect.signature(model_to_inspect.forward)
+        columns = list(signature.parameters.keys())
+        # Labels may be named label or label_ids, the default data collator handles that.
+        columns += list(set(["label", "label_ids"] + self.label_names))
+        return columns
+
     def _set_signature_columns_if_needed(self):
         if self._signature_columns is None:
-            # Inspect model forward signature to keep only the arguments it accepts.
-            model_to_inspect = self.model
-            if _is_peft_model(self.model):
-                if hasattr(self.model, "get_base_model"):
-                    model_to_inspect = self.model.get_base_model()
-                else:
-                    # PeftMixedModel do not provide a `get_base_model` method
-                    model_to_inspect = self.model.base_model.model
-            signature = inspect.signature(model_to_inspect.forward)
-            self._signature_columns = list(signature.parameters.keys())
-            # Labels may be named label or label_ids, the default data collator handles that.
-            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
+            self._signature_columns = self._get_signature_columns()
 
     def _remove_unused_columns(self, dataset: "datasets.Dataset", description: str | None = None):
         if not self.args.remove_unused_columns: