4 changes: 3 additions & 1 deletion src/transformers/trainer.py
@@ -2237,7 +2237,9 @@ def train(
         self.is_in_train = True
 
         # If the model uses a tokenizer, it may have new tokens added for fine-tuning purposes.
-        if isinstance(self.processing_class, (PreTrainedTokenizerBase, ProcessorMixin)):
+        if isinstance(self.processing_class, (PreTrainedTokenizerBase, ProcessorMixin)) and hasattr(
+            self.model, "config"
+        ):
             self._align_special_tokens()
 
         # Attach NEFTune hooks if necessary
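For context, a minimal sketch of the failure mode the new `hasattr` guard avoids. This repro is hypothetical (not part of the PR; it reuses the tiny checkpoint from the test below): before this change, `train()` unconditionally called `_align_special_tokens()`, which reads `model.config` and so raised `AttributeError` for plain `nn.Module` models.

# Hypothetical repro sketch; NoConfigModel is illustrative, not from this PR.
import torch.nn as nn
from transformers import AutoTokenizer, Trainer, TrainingArguments

class NoConfigModel(nn.Module):  # deliberately has no `config` attribute
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, 16)

    def forward(self, input_ids, **kwargs):
        return self.embedding(input_ids)

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
trainer = Trainer(
    model=NoConfigModel(tokenizer.vocab_size),
    args=TrainingArguments(output_dir="out", report_to="none", max_steps=1),
    processing_class=tokenizer,
)
# With a train_dataset supplied, trainer.train() now skips special-token
# alignment for this model instead of raising AttributeError on model.config.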
46 changes: 44 additions & 2 deletions tests/trainer/test_trainer.py
@@ -516,11 +516,15 @@ def __init__(self, vocab_size, hidden_size):
         self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
         self.fc = nn.Linear(hidden_size, vocab_size)
 
-    def forward(self, input_ids, **kwargs):
+    def forward(self, input_ids, labels=None, **kwargs):
         embedded = self.embedding(input_ids)
         lstm_out, _ = self.lstm(embedded)
         logits = self.fc(lstm_out)
-        return logits
+        if labels is None:
+            return logits
+
+        loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
+        return loss, logits
 
 def create_dummy_dataset_for_text_generation(vocab_size, seq_length, num_samples):
     import numpy as np
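A note on why the loss comes first in the returned tuple (my reading of Trainer's behavior, not stated in the diff): for non-dict model outputs, `Trainer.compute_loss` takes `outputs[0]` as the loss, so returning `(loss, logits)` when `labels` are provided lets this toy LSTM train under `Trainer`. A standalone sketch of the contract, with illustrative shapes:

# Illustrative sketch of the updated forward() contract; shapes are assumptions.
import torch

model = BasicTextGenerationModel(vocab_size=100, hidden_size=32)
input_ids = torch.randint(0, 100, (2, 8))          # (batch_size, seq_len)

logits = model(input_ids)                          # no labels -> logits, shape (2, 8, 100)
loss, logits = model(input_ids, labels=input_ids)  # labels -> (scalar loss, logits)
# Trainer.compute_loss uses outputs[0] as the loss for tuple outputs,
# hence the loss is returned first.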
@@ -5021,6 +5025,44 @@ def test_special_token_aligment(self):
         self.assertEqual(trainer.model.config.pad_token_id, tokenizer.pad_token_id)
         self.assertEqual(trainer.model.config.bos_token_id, tokenizer.bos_token_id)
 
+    def test_trainer_works_without_model_config(self):
+        """
+        Tests that models without a `config` attribute can still be trained.
+        This preserves compatibility with third-party code that trains custom models with the
+        transformers Trainer.
+
+        If this test fails, it doesn't necessarily indicate an issue in transformers itself, but it
+        may break such third-party code.
+        """
+
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        model = BasicTextGenerationModel(vocab_size=tokenizer.vocab_size, hidden_size=32)
+        # Note that this class does not have a `config` attribute.
+
+        train_dataset = LineByLineTextDataset(
+            tokenizer=tokenizer,
+            file_path=PATH_SAMPLE_TEXT,
+            block_size=tokenizer.max_len_single_sentence,
+        )
+        for example in train_dataset.examples:
+            example["labels"] = example["input_ids"]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            training_args = TrainingArguments(
+                output_dir=tmpdir,
+                report_to="none",
+                max_steps=5,
+                per_device_train_batch_size=1,
+                remove_unused_columns=False,
+            )
+            trainer = Trainer(
+                model=model,
+                args=training_args,
+                processing_class=tokenizer,
+                train_dataset=train_dataset,
+            )
+            trainer.train()
+
 
 @require_torch
 @is_staging_test