Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,15 +783,17 @@ def __init__(
# returned to 0 every time flos need to be logged
self.current_flos = 0
self.hp_search_backend = None
if _is_peft_model(self.model) and self.args.label_names is None:
logger.warning(
f"No label_names provided for model class `{self.model.__class__.__name__}`."
" Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`."
" Note that empty label_names list will be used instead."
)
default_label_names = find_labels(self.model.__class__)

model_to_inspect = self.model
if _is_peft_model(self.model):
if hasattr(self.model, "get_base_model"):
model_to_inspect = self.model.get_base_model()
else:
# PeftMixedModel do not provide a `get_base_model` method
model_to_inspect = self.model.base_model.model
default_label_names = find_labels(model_to_inspect.__class__)
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.can_return_loss = can_return_loss(self.model.__class__)
self.can_return_loss = can_return_loss(model_to_inspect.__class__)
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

# Internal variables to help with automatic batch size reduction
Expand Down