From b90b518f297d58c35d5f37778bfb1562db2da16a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 8 Jul 2025 03:13:14 +0000 Subject: [PATCH] Refactor label name handling for PEFT models in Trainer class --- src/transformers/trainer.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ddc447b6a473..14db3deedb2d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -777,15 +777,17 @@ def __init__( # returned to 0 every time flos need to be logged self.current_flos = 0 self.hp_search_backend = None - if _is_peft_model(self.model) and self.args.label_names is None: - logger.warning( - f"No label_names provided for model class `{self.model.__class__.__name__}`." - " Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`." - " Note that empty label_names list will be used instead." - ) - default_label_names = find_labels(self.model.__class__) + + model_to_inspect = self.model + if _is_peft_model(self.model): + if hasattr(self.model, "get_base_model"): + model_to_inspect = self.model.get_base_model() + else: + # PeftMixedModel do not provide a `get_base_model` method + model_to_inspect = self.model.base_model.model + default_label_names = find_labels(model_to_inspect.__class__) self.label_names = default_label_names if self.args.label_names is None else self.args.label_names - self.can_return_loss = can_return_loss(self.model.__class__) + self.can_return_loss = can_return_loss(model_to_inspect.__class__) self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) # Internal variables to help with automatic batch size reduction