Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,20 +925,21 @@ def _align_special_tokens(self):
f"values. Updated tokens: {updated_tokens}."
)

def _get_signature_columns(self) -> list[str]:
model_to_inspect = self.model
if _is_peft_model(self.model):
if hasattr(self.model, "get_base_model"):
model_to_inspect = self.model.get_base_model()
else:
model_to_inspect = self.model.base_model.model
signature = inspect.signature(model_to_inspect.forward)
columns = list(signature.parameters.keys())
columns += list(set(["label", "label_ids"] + self.label_names))
return columns

def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
model_to_inspect = self.model
if _is_peft_model(self.model):
if hasattr(self.model, "get_base_model"):
model_to_inspect = self.model.get_base_model()
else:
# PeftMixedModel do not provide a `get_base_model` method
model_to_inspect = self.model.base_model.model
signature = inspect.signature(model_to_inspect.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
self._signature_columns = self._get_signature_columns()

def _remove_unused_columns(self, dataset: "datasets.Dataset", description: str | None = None):
if not self.args.remove_unused_columns:
Expand Down