diff --git a/src/transformers/tokenization_utils_tokenizers.py b/src/transformers/tokenization_utils_tokenizers.py
index afca202127be..b0d65b7c00c4 100644
--- a/src/transformers/tokenization_utils_tokenizers.py
+++ b/src/transformers/tokenization_utils_tokenizers.py
@@ -35,7 +35,7 @@
 from transformers.utils.hub import cached_file
-from .convert_slow_tokenizer import SpmConverter
+from .convert_slow_tokenizer import SpmConverter, bytes_to_unicode
 from .integrations.ggml import convert_gguf_tokenizer
 from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
 from .tokenization_utils_base import (
@@ -51,6 +51,7 @@
 logger = logging.get_logger(__name__)
 
+BYTE_TO_UNICODE = bytes_to_unicode()
 
 # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
 TOKENIZER_FILE = "tokenizer.json"
@@ -726,9 +727,49 @@ def _convert_id_to_token(self, index: int) -> str | None:
     def _add_tokens(self, new_tokens: list[str | AddedToken], special_tokens=False) -> int:
         if special_tokens:
             return self._tokenizer.add_special_tokens(new_tokens)
-
+        new_tokens = self._maybe_encode_added_tokens_for_bytelevel(new_tokens)
         return self._tokenizer.add_tokens(new_tokens)
 
+    def _maybe_encode_added_tokens_for_bytelevel(self, new_tokens: list[str | AddedToken]) -> list[str | AddedToken]:
+        pre_tokenizer = getattr(self.backend_tokenizer, "pre_tokenizer", None)
+        decoder = getattr(self.backend_tokenizer, "decoder", None)
+        normalizer = getattr(self.backend_tokenizer, "normalizer", None)
+
+        def _contains_bytelevel(component: Any) -> bool:
+            if component is None:
+                return False
+            if component.__class__.__name__ == "ByteLevel":
+                return True
+            # Some tokenizers expose wrappers like `Sequence([... ByteLevel(...) ...])`.
+            # We use repr-based detection as these wrappers do not consistently expose
+            # iterable internals in the Python bindings.
+            return "ByteLevel(" in repr(component)
+
+        # Some ByteLevel tokenizers (e.g. GPT-2/Qwen families) may use a ByteLevel pre-tokenizer/decoder
+        # without a ByteLevel normalizer. In this setup, raw unicode added tokens can decode incorrectly
+        # (e.g. U+010D -> '\r'). Encoding added token contents through the ByteLevel alphabet
+        # preserves roundtrip behavior.
+        if _contains_bytelevel(pre_tokenizer) and _contains_bytelevel(decoder) and not _contains_bytelevel(normalizer):
+            encoded_tokens: list[str | AddedToken] = []
+            for token in new_tokens:
+                if isinstance(token, AddedToken):
+                    encoded_content = "".join(BYTE_TO_UNICODE[b] for b in token.content.encode("utf-8"))
+                    encoded_tokens.append(
+                        AddedToken(
+                            encoded_content,
+                            single_word=token.single_word,
+                            lstrip=token.lstrip,
+                            rstrip=token.rstrip,
+                            normalized=token.normalized,
+                            special=token.special,
+                        )
+                    )
+                else:
+                    encoded_tokens.append("".join(BYTE_TO_UNICODE[b] for b in token.encode("utf-8")))
+            return encoded_tokens
+
+        return new_tokens
+
     def num_special_tokens_to_add(self, pair: bool = False) -> int:
         """
         Returns the number of added tokens when encoding a sequence with special tokens.
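For reference, here is a minimal sketch of the transformation the new helper applies to an added token's content. It assumes only `bytes_to_unicode` as imported in the patch above; the `encode_for_bytelevel` function name is made up for illustration:

```python
from transformers.convert_slow_tokenizer import bytes_to_unicode

byte_to_unicode = bytes_to_unicode()


def encode_for_bytelevel(text: str) -> str:
    # Map each UTF-8 byte of the token to its ByteLevel alphabet character,
    # mirroring the per-token loop in _maybe_encode_added_tokens_for_bytelevel.
    return "".join(byte_to_unicode[b] for b in text.encode("utf-8"))


# "č" (U+010D) is the UTF-8 byte pair 0xC4 0x8D, which the ByteLevel alphabet
# renders as "Ä" and "į": the form a ByteLevel vocabulary actually stores.
print(encode_for_bytelevel("Začnimo"))  # -> "ZaÄįnimo"
```

This is why a raw Unicode string added to the vocabulary can fail to round-trip: the encoder produces tokens in the ByteLevel alphabet, so an added token stored as raw UTF-8 never matches, and the ByteLevel decoder misinterprets its characters on the way back.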
diff --git a/tests/models/gpt2/test_tokenization_gpt2.py b/tests/models/gpt2/test_tokenization_gpt2.py
index 8e409064320c..859aa8232851 100644
--- a/tests/models/gpt2/test_tokenization_gpt2.py
+++ b/tests/models/gpt2/test_tokenization_gpt2.py
@@ -84,6 +84,18 @@ def test_tokenization_tiktoken(self):
             tiktoken_fast_tokenizer.decode(rust_tokenizer.encode(sequence)),
         )
 
+    def test_added_tokens_unicode_roundtrip_with_bytelevel(self):
+        """Regression (#45051): added vocabulary with Unicode must encode/decode cleanly for ByteLevel without a normalizer."""
+        tokenizer = AutoTokenizer.from_pretrained(self.from_pretrained_id[0])
+        new_tokens = ["Začnimo", "kuća", "međa"]
+        tokenizer.add_tokens(new_tokens)
+
+        for word in new_tokens:
+            with self.subTest(word=word):
+                ids = tokenizer.encode(word, add_special_tokens=False)
+                decoded = tokenizer.decode(ids, skip_special_tokens=False)
+                self.assertEqual(decoded, word)
+
 
 @require_tokenizers
 class OPTTokenizationTest(unittest.TestCase):
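The behavior the test locks in can be checked by hand along these lines (the checkpoint id here is an assumption for illustration; the test resolves its own via `self.from_pretrained_id[0]`):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.add_tokens(["Začnimo"])

ids = tokenizer.encode("Začnimo", add_special_tokens=False)
print(tokenizer.decode(ids, skip_special_tokens=False))
# With the fix this prints "Začnimo". Without it, the ByteLevel decoder
# reinterprets the raw added token and "č" (U+010D) can come back as "\r".
```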