43 changes: 1 addition & 42 deletions tests/test_dataset_formatting.py
@@ -19,7 +19,7 @@
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl.extras.dataset_formatting import get_formatting_func_from_dataset
from trl.models.utils import ChatMlSpecialTokens, clone_chat_template, setup_chat_format
from trl.models.utils import clone_chat_template

from .testing_utils import TrlTestCase

@@ -118,47 +118,6 @@ def test_get_formatting_func_from_dataset_with_unknown_format(self):
assert formatting_func is None


@pytest.mark.filterwarnings("ignore::FutureWarning")
class TestSetupChatFormat(TrlTestCase):
def setup_method(self):
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
# remove built-in chat_template to simulate a model having no chat_template
self.tokenizer.chat_template = None

def test_setup_chat_format(self):
modified_model, modified_tokenizer = setup_chat_format(
self.model, self.tokenizer, format="chatml", resize_to_multiple_of=123
)

_chatml = ChatMlSpecialTokens()
# Check if special tokens are correctly set
assert modified_tokenizer.eos_token == "<|im_end|>"
assert modified_tokenizer.pad_token == "<|im_end|>"
assert modified_tokenizer.bos_token == "<|im_start|>"
assert modified_tokenizer.eos_token == _chatml.eos_token
assert modified_tokenizer.pad_token == _chatml.pad_token
assert modified_tokenizer.bos_token == _chatml.bos_token
assert (modified_model.vocab_size % 123) == 0

def test_example_with_setup_model(self):
modified_model, modified_tokenizer = setup_chat_format(
self.model,
self.tokenizer,
)
messages = [
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help you?"},
]
prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)

assert (
prompt
== "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
)


class TestCloneChatTemplate(TrlTestCase):
def test_clone(self):
# This tokenizer doesn't have a chat_template by default
2 changes: 0 additions & 2 deletions trl/__init__.py
@@ -49,7 +49,6 @@
"PreTrainedModelWrapper",
"clone_chat_template",
"create_reference_model",
"setup_chat_format",
],
"trainer": [
"AllTrueJudge",
@@ -129,7 +128,6 @@
PreTrainedModelWrapper,
clone_chat_template,
create_reference_model,
setup_chat_format,
)
from .scripts import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose
from .trainer import (
2 changes: 0 additions & 2 deletions trl/models/__init__.py
@@ -28,7 +28,6 @@
"prepare_fsdp",
"prepare_model_for_kbit_training",
"prepare_peft_model",
"setup_chat_format",
"unwrap_model_for_generation",
],
}
@@ -45,7 +44,6 @@
prepare_fsdp,
prepare_model_for_kbit_training,
prepare_peft_model,
setup_chat_format,
unwrap_model_for_generation,
)
else:
79 changes: 1 addition & 78 deletions trl/models/utils.py
@@ -14,12 +14,11 @@

import inspect
import itertools
import warnings
from collections.abc import Callable
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any

import torch
import torch.nn as nn
@@ -85,82 +84,6 @@ def chat_template(self):
FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}


def setup_chat_format(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
format: Literal["chatml"] | None = "chatml",
resize_to_multiple_of: int | None = None,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
# docstyle-ignore
"""
Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
embedding layer of the model based on the new special tokens.

> [!WARNING]
> This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead.

If the model already has a chat template, this will throw an error. If you want to overwrite it, please set
`tokenizer.chat_template` to `None`.

Args:
model ([`~transformers.PreTrainedModel`]): The model to be modified.
tokenizer ([`~transformers.PreTrainedTokenizer`]): The tokenizer to be modified.
format (`Literal["chatml"] | None`): The format to be set. Defaults to "chatml".
resize_to_multiple_of (`int` or `None`): Number to resize the embedding layer to. Defaults to None.

Returns:
model ([`~transformers.PreTrainedModel`]):
The modified model.
tokenizer ([`~transformers.PreTrainedTokenizer`]):
The modified tokenizer.
"""
warnings.warn(
"The `setup_chat_format` function is deprecated and will be removed in version 0.26.0. Please use "
"`clone_chat_template` instead.",
FutureWarning,
)
# check if model already had a chat template
if tokenizer.chat_template is not None:
raise ValueError(
"Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None"
)

# check if format available and retrieve
if format not in FORMAT_MAPPING:
raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")

chat_format = FORMAT_MAPPING[format]()

    # set special tokens and add them to the tokenizer
tokenizer.eos_token = chat_format.eos_token
tokenizer.pad_token = chat_format.pad_token
tokenizer.bos_token = chat_format.bos_token
tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
# set chat format for tokenizer
tokenizer.chat_template = chat_format.chat_template

# resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377
model.resize_token_embeddings(
# After studying many tokenizers, we found that len(tokenizer.vocab) is the most reliable way to get the vocab
# size. Avoid using tokenizer.vocab_size or tokenizer.vocab_size + len(tokenizer.added_tokens_encoder),
# as handling of special and added tokens varies across tokenizers.
new_num_tokens=len(tokenizer.vocab),
pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None,
)
# Update the model config to use the new eos & bos tokens
if getattr(model, "config", None) is not None:
model.config.pad_token_id = tokenizer.pad_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
# Update the generation config to use the new eos & bos token
if getattr(model, "generation_config", None) is not None:
model.generation_config.bos_token_id = tokenizer.bos_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

return model, tokenizer


def clone_chat_template(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
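Code that previously called the removed `setup_chat_format` helper should move to `clone_chat_template`, which copies the chat template and any missing special tokens from an existing chat model and grows the embedding matrix as needed. A minimal migration sketch, assuming `clone_chat_template` still returns a `(model, tokenizer, added_tokens)` triple; the checkpoint names are purely illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import clone_chat_template

# Illustrative checkpoints; substitute the model you are fine-tuning.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

# Before (removed in this PR):
#     model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
# After: copy the template and special tokens from a model that already ships one.
model, tokenizer, added_tokens = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")

Unlike `setup_chat_format`, the replacement does not hard-code the ChatML tokens: it inherits whatever template the source checkpoint provides, so pick a source that matches the format you want.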
16 changes: 0 additions & 16 deletions trl/trainer/dpo_config.py
@@ -222,16 +222,6 @@ class DPOConfig(TrainingArguments):
generate_during_eval (`bool`, *optional*, defaults to `False`):
Whether to generate and log completions from both the model and the reference model to W&B or Comet during
evaluation.

> Deprecated parameters

padding_value:

<Deprecated version="0.24.0">

This parameter is deprecated and will be removed in version 0.26.0. Use `pad_token` (`str`) instead.

</Deprecated>
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"]
@@ -508,12 +498,6 @@ class DPOConfig(TrainingArguments):
},
)

# Deprecated arguments
padding_value: int | None = field(
default=None,
metadata={"help": "Deprecated, use `pad_token` (str) instead."},
)

def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
self.f_divergence_type = FDivergenceType(self.f_divergence_type)
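With the deprecated `padding_value` field gone, padding is configured through `pad_token`, which takes the token string rather than a raw id. A small before/after sketch; the token string and output directory are illustrative:

from trl import DPOConfig

# Before (removed in this PR): the padding token was passed as a raw id.
#     training_args = DPOConfig(output_dir="dpo-output", padding_value=0)

# After: pass the token string via `pad_token`; the trainer resolves its id.
training_args = DPOConfig(output_dir="dpo-output", pad_token="<|pad|>")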
38 changes: 7 additions & 31 deletions trl/trainer/dpo_trainer.py
@@ -15,7 +15,6 @@
import inspect
import random
import textwrap
import warnings
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
@@ -342,21 +341,14 @@ def __init__(

# Get the pad token: if not provided, use the one from the processing class or the eos token
# if the processing class does not have a pad token.
if args.padding_value is not None: # deprecated, will be removed in 0.26.0.
warnings.warn(
"The `padding_value` argument is deprecated and will be removed in version 0.26.0. Please use "
"`pad_token` (str) instead."
pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
self.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
if self.pad_token_id is None:
raise ValueError(
f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
"in the vocabulary before using it as a padding token."
)
self.pad_token_id = args.padding_value
else:
pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
self.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
if self.pad_token_id is None:
raise ValueError(
f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
"in the vocabulary before using it as a padding token."
)

# PEFT configuration and model wrapping
model = self._prepare_peft_model(model, ref_model, peft_config, args)
@@ -558,22 +550,6 @@ def __init__(
if "bco_pair" in self.loss_type:
self.running = RunningMoments(self.accelerator)

@property
def padding_value(self):
warnings.warn(
"The `padding_value` property is deprecated and will be removed in version 0.26.0. Please use "
"`pad_token_id` instead.",
)
return self.pad_token_id

@padding_value.setter
def padding_value(self, value):
warnings.warn(
"The `padding_value` property is deprecated and will be removed in version 0.26.0. Please use "
"`pad_token_id` instead.",
)
self.pad_token_id = value

def _prepare_peft_model(
self, model: PreTrainedModel, ref_model: PreTrainedModel, peft_config: Any, args: DPOConfig
) -> PreTrainedModel:
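With the deprecation branch and the `padding_value` property removed, the trainer resolves its padding token in one place: `args.pad_token` first, then the processing class's `pad_token`, then its `eos_token`, raising if the resulting string is not in the vocabulary. A standalone sketch of that fallback chain; the function name, checkpoint, and override token are illustrative, not part of the TRL API:

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def resolve_pad_token_id(tokenizer: PreTrainedTokenizerBase, pad_token: str | None = None) -> int:
    # Mirror the trainer's fallback chain: explicit arg -> tokenizer.pad_token -> tokenizer.eos_token.
    token = pad_token or tokenizer.pad_token or tokenizer.eos_token
    token_id = tokenizer.convert_tokens_to_ids(token)
    if token_id is None:
        raise ValueError(f"The specified `pad_token` ('{token}') is not found in the vocabulary.")
    return token_id


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
print(resolve_pad_token_id(tokenizer))                    # falls back to the tokenizer's own pad/eos token
print(resolve_pad_token_id(tokenizer, "<|endoftext|>"))   # an explicit override wins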