From fdaa06fbc8362639b5b6044e6c8e4f4e5cd256a2 Mon Sep 17 00:00:00 2001 From: ErenAta16 Date: Fri, 27 Mar 2026 18:22:36 +0300 Subject: [PATCH 1/8] test: add xfail regression for ByteLevel added-token unicode decode Add a GPT-2 regression test that captures added token Unicode decode corruption with ByteLevel tokenizers and mark it xfail while the underlying tokenizers-layer fix is pending. Made-with: Cursor --- tests/models/gpt2/test_tokenization_gpt2.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/models/gpt2/test_tokenization_gpt2.py b/tests/models/gpt2/test_tokenization_gpt2.py index 8e409064320c..e3eea21f35f6 100644 --- a/tests/models/gpt2/test_tokenization_gpt2.py +++ b/tests/models/gpt2/test_tokenization_gpt2.py @@ -15,6 +15,8 @@ import unittest +import pytest + from transformers import AutoTokenizer, GPT2Tokenizer from transformers.testing_utils import require_tiktoken, require_tokenizers @@ -84,6 +86,25 @@ def test_tokenization_tiktoken(self): tiktoken_fast_tokenizer.decode(rust_tokenizer.encode(sequence)), ) + @pytest.mark.xfail( + reason="Blocked by huggingface/tokenizers ByteLevel added-token decode behavior for certain Unicode chars.", + strict=False, + ) + def test_added_tokens_unicode_roundtrip_with_bytelevel(self): + tokenizer_fast = AutoTokenizer.from_pretrained("gpt2", use_fast=True) + tokenizer_slow = AutoTokenizer.from_pretrained("gpt2", use_fast=False) + + new_tokens = ["Začnimo", "kuća", "međa"] + tokenizer_fast.add_tokens(new_tokens) + tokenizer_slow.add_tokens(new_tokens) + + for tokenizer in (tokenizer_fast, tokenizer_slow): + with self.subTest(tokenizer_class=tokenizer.__class__.__name__): + for word in new_tokens: + ids = tokenizer.encode(word, add_special_tokens=False) + decoded = tokenizer.decode(ids, skip_special_tokens=False) + self.assertEqual(decoded, word) + @require_tokenizers class OPTTokenizationTest(unittest.TestCase): From fcd5b203c23c90926beca6e0d0a90d89d3cc9f8d Mon Sep 17 00:00:00 2001 From: ErenAta16 Date: Fri, 27 Mar 2026 18:32:41 +0300 Subject: [PATCH 2/8] Revert "test: add xfail regression for ByteLevel added-token unicode decode" This reverts commit fdaa06fbc8362639b5b6044e6c8e4f4e5cd256a2. 
--- tests/models/gpt2/test_tokenization_gpt2.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/tests/models/gpt2/test_tokenization_gpt2.py b/tests/models/gpt2/test_tokenization_gpt2.py index e3eea21f35f6..8e409064320c 100644 --- a/tests/models/gpt2/test_tokenization_gpt2.py +++ b/tests/models/gpt2/test_tokenization_gpt2.py @@ -15,8 +15,6 @@ import unittest -import pytest - from transformers import AutoTokenizer, GPT2Tokenizer from transformers.testing_utils import require_tiktoken, require_tokenizers @@ -86,25 +84,6 @@ def test_tokenization_tiktoken(self): tiktoken_fast_tokenizer.decode(rust_tokenizer.encode(sequence)), ) - @pytest.mark.xfail( - reason="Blocked by huggingface/tokenizers ByteLevel added-token decode behavior for certain Unicode chars.", - strict=False, - ) - def test_added_tokens_unicode_roundtrip_with_bytelevel(self): - tokenizer_fast = AutoTokenizer.from_pretrained("gpt2", use_fast=True) - tokenizer_slow = AutoTokenizer.from_pretrained("gpt2", use_fast=False) - - new_tokens = ["Začnimo", "kuća", "međa"] - tokenizer_fast.add_tokens(new_tokens) - tokenizer_slow.add_tokens(new_tokens) - - for tokenizer in (tokenizer_fast, tokenizer_slow): - with self.subTest(tokenizer_class=tokenizer.__class__.__name__): - for word in new_tokens: - ids = tokenizer.encode(word, add_special_tokens=False) - decoded = tokenizer.decode(ids, skip_special_tokens=False) - self.assertEqual(decoded, word) - @require_tokenizers class OPTTokenizationTest(unittest.TestCase): From 538c6bbcfa54ca8f6e818166e1a714872fd796bc Mon Sep 17 00:00:00 2001 From: ErenAta16 Date: Fri, 27 Mar 2026 18:41:56 +0300 Subject: [PATCH 3/8] fix ByteLevel added-token unicode decoding for GPT2-like tokenizers Encode newly added tokens through the ByteLevel unicode alphabet when the backend uses a ByteLevel pre-tokenizer and decoder without a normalizer, preventing control-character corruption on decode. Add a GPT-2 regression test to validate unicode roundtrip for added tokens. 
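For illustration, a minimal sketch of the mapping this change applies, using the
bytes_to_unicode table the patch imports from convert_slow_tokenizer (the sample
token and the expected output are illustrative, not part of the patched code):

    table = bytes_to_unicode()
    encoded = "".join(table[b] for b in "kuća".encode("utf-8"))
    # "ć" (U+0107) is two UTF-8 bytes (0xC4 0x87), so it maps to the two
    # alphabet characters "Ä" and "ĩ", giving encoded == "kuÄĩa"; the
    # ByteLevel decoder inverts the same table on decode and restores
    # "kuća" instead of collapsing code points into control characters.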
Made-with: Cursor --- .../tokenization_utils_tokenizers.py | 40 +++++++++++++++++++ tests/models/gpt2/test_tokenization_gpt2.py | 11 +++++ 2 files changed, 51 insertions(+) diff --git a/src/transformers/tokenization_utils_tokenizers.py b/src/transformers/tokenization_utils_tokenizers.py index afca202127be..6a6a924a098c 100644 --- a/src/transformers/tokenization_utils_tokenizers.py +++ b/src/transformers/tokenization_utils_tokenizers.py @@ -35,6 +35,7 @@ from transformers.utils.hub import cached_file +from .convert_slow_tokenizer import bytes_to_unicode from .convert_slow_tokenizer import SpmConverter from .integrations.ggml import convert_gguf_tokenizer from .modeling_gguf_pytorch_utils import load_gguf_checkpoint @@ -51,6 +52,7 @@ logger = logging.get_logger(__name__) +BYTE_TO_UNICODE = bytes_to_unicode() # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file TOKENIZER_FILE = "tokenizer.json" @@ -724,11 +726,49 @@ def _convert_id_to_token(self, index: int) -> str | None: return self._tokenizer.id_to_token(int(index)) def _add_tokens(self, new_tokens: list[str | AddedToken], special_tokens=False) -> int: + if not special_tokens: + new_tokens = self._maybe_encode_added_tokens_for_bytelevel(new_tokens) if special_tokens: return self._tokenizer.add_special_tokens(new_tokens) return self._tokenizer.add_tokens(new_tokens) + def _maybe_encode_added_tokens_for_bytelevel(self, new_tokens: list[str | AddedToken]) -> list[str | AddedToken]: + pre_tokenizer = getattr(self.backend_tokenizer, "pre_tokenizer", None) + decoder = getattr(self.backend_tokenizer, "decoder", None) + normalizer = getattr(self.backend_tokenizer, "normalizer", None) + + # Some ByteLevel tokenizers (e.g. GPT-2 family) have ByteLevel pre-tokenizer/decoder + # but no normalizer. In this setup, raw unicode added tokens can decode incorrectly + # (e.g. U+010D -> '\r'). Encoding added token contents through the ByteLevel alphabet + # preserves roundtrip behavior. + if ( + normalizer is None + and pre_tokenizer is not None + and pre_tokenizer.__class__.__name__ == "ByteLevel" + and decoder is not None + and decoder.__class__.__name__ == "ByteLevel" + ): + encoded_tokens: list[str | AddedToken] = [] + for token in new_tokens: + if isinstance(token, AddedToken): + encoded_content = "".join(BYTE_TO_UNICODE[b] for b in token.content.encode("utf-8")) + encoded_tokens.append( + AddedToken( + encoded_content, + single_word=token.single_word, + lstrip=token.lstrip, + rstrip=token.rstrip, + normalized=token.normalized, + special=token.special, + ) + ) + else: + encoded_tokens.append("".join(BYTE_TO_UNICODE[b] for b in token.encode("utf-8"))) + return encoded_tokens + + return new_tokens + def num_special_tokens_to_add(self, pair: bool = False) -> int: """ Returns the number of added tokens when encoding a sequence with special tokens. 
diff --git a/tests/models/gpt2/test_tokenization_gpt2.py b/tests/models/gpt2/test_tokenization_gpt2.py index 8e409064320c..42b5c1491e31 100644 --- a/tests/models/gpt2/test_tokenization_gpt2.py +++ b/tests/models/gpt2/test_tokenization_gpt2.py @@ -84,6 +84,17 @@ def test_tokenization_tiktoken(self): tiktoken_fast_tokenizer.decode(rust_tokenizer.encode(sequence)), ) + def test_added_tokens_unicode_roundtrip_with_bytelevel(self): + tokenizer = AutoTokenizer.from_pretrained("gpt2") + new_tokens = ["Začnimo", "kuća", "međa"] + tokenizer.add_tokens(new_tokens) + + for word in new_tokens: + with self.subTest(word=word): + ids = tokenizer.encode(word, add_special_tokens=False) + decoded = tokenizer.decode(ids, skip_special_tokens=False) + self.assertEqual(decoded, word) + @require_tokenizers class OPTTokenizationTest(unittest.TestCase): From d4e214b5cf921218e28fc3ae89fb1d81f51c0498 Mon Sep 17 00:00:00 2001 From: ErenAta16 Date: Fri, 27 Mar 2026 18:56:34 +0300 Subject: [PATCH 4/8] fix ByteLevel added-token handling when normalizer is non-ByteLevel Apply ByteLevel encoding to newly added tokens whenever tokenizer decoding uses ByteLevel but normalization does not, covering setups like Qwen (NFC normalizer + ByteLevel pre-tokenizer/decoder) and preventing unicode-to-control-character corruption on decode. Made-with: Cursor --- .../tokenization_utils_tokenizers.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/transformers/tokenization_utils_tokenizers.py b/src/transformers/tokenization_utils_tokenizers.py index 6a6a924a098c..e21ee2170a00 100644 --- a/src/transformers/tokenization_utils_tokenizers.py +++ b/src/transformers/tokenization_utils_tokenizers.py @@ -737,17 +737,25 @@ def _maybe_encode_added_tokens_for_bytelevel(self, new_tokens: list[str | AddedT pre_tokenizer = getattr(self.backend_tokenizer, "pre_tokenizer", None) decoder = getattr(self.backend_tokenizer, "decoder", None) normalizer = getattr(self.backend_tokenizer, "normalizer", None) + + def _contains_bytelevel(component: Any) -> bool: + if component is None: + return False + if component.__class__.__name__ == "ByteLevel": + return True + # Some tokenizers expose wrappers like `Sequence([... ByteLevel(...) ...])`. + # We use repr-based detection as these wrappers do not consistently expose + # iterable internals in the Python bindings. + return "ByteLevel(" in repr(component) - # Some ByteLevel tokenizers (e.g. GPT-2 family) have ByteLevel pre-tokenizer/decoder - # but no normalizer. In this setup, raw unicode added tokens can decode incorrectly + # Some ByteLevel tokenizers (e.g. GPT-2/Qwen families) may use ByteLevel pre-tokenizer/decoder + # without a ByteLevel normalizer. In this setup, raw unicode added tokens can decode incorrectly # (e.g. U+010D -> '\r'). Encoding added token contents through the ByteLevel alphabet # preserves roundtrip behavior. 
if ( - normalizer is None - and pre_tokenizer is not None - and pre_tokenizer.__class__.__name__ == "ByteLevel" - and decoder is not None - and decoder.__class__.__name__ == "ByteLevel" + _contains_bytelevel(pre_tokenizer) + and _contains_bytelevel(decoder) + and not _contains_bytelevel(normalizer) ): encoded_tokens: list[str | AddedToken] = [] for token in new_tokens: From cdb41be88b1a3575df8c4b7369458d61e59bc21f Mon Sep 17 00:00:00 2001 From: ErenAta16 Date: Fri, 27 Mar 2026 19:20:35 +0300 Subject: [PATCH 5/8] chore: fix CI lint and typing follow-ups for ByteLevel tokenizer patch Remove a stale type ignore in generation utils and clean formatting/import ordering so check_code_quality passes on the PR branch. Made-with: Cursor --- src/transformers/generation/utils.py | 2 +- src/transformers/tokenization_utils_tokenizers.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8a55c184b0f0..85beb2a03f20 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2540,7 +2540,7 @@ def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device) # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # type: ignore + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: return False diff --git a/src/transformers/tokenization_utils_tokenizers.py b/src/transformers/tokenization_utils_tokenizers.py index e21ee2170a00..7fff15f21b90 100644 --- a/src/transformers/tokenization_utils_tokenizers.py +++ b/src/transformers/tokenization_utils_tokenizers.py @@ -35,8 +35,7 @@ from transformers.utils.hub import cached_file -from .convert_slow_tokenizer import bytes_to_unicode -from .convert_slow_tokenizer import SpmConverter +from .convert_slow_tokenizer import SpmConverter, bytes_to_unicode from .integrations.ggml import convert_gguf_tokenizer from .modeling_gguf_pytorch_utils import load_gguf_checkpoint from .tokenization_utils_base import ( @@ -737,7 +736,7 @@ def _maybe_encode_added_tokens_for_bytelevel(self, new_tokens: list[str | AddedT pre_tokenizer = getattr(self.backend_tokenizer, "pre_tokenizer", None) decoder = getattr(self.backend_tokenizer, "decoder", None) normalizer = getattr(self.backend_tokenizer, "normalizer", None) - + def _contains_bytelevel(component: Any) -> bool: if component is None: return False From c4d8d546c212da2569512d0e985f491854e35df7 Mon Sep 17 00:00:00 2001 From: ErenAta16 Date: Sat, 28 Mar 2026 00:20:30 +0300 Subject: [PATCH 6/8] chore: drop unrelated generation utils diff from ByteLevel tokenizer PR Restore dist.all_reduce line to match upstream main so check_code_quality stays aligned with the type checker configuration. 
Made-with: Cursor --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 85beb2a03f20..8a55c184b0f0 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2540,7 +2540,7 @@ def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device) # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # type: ignore # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: return False From b22d5de9b8b6c55105807424deb5bfe6eedd380f Mon Sep 17 00:00:00 2001 From: ErenAta16 Date: Sat, 28 Mar 2026 00:33:58 +0300 Subject: [PATCH 7/8] style: ruff-format ByteLevel added-token branch condition CI check_code_quality runs `ruff format --check`; collapse the multi-line if to match formatter output. Made-with: Cursor --- src/transformers/tokenization_utils_tokenizers.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transformers/tokenization_utils_tokenizers.py b/src/transformers/tokenization_utils_tokenizers.py index 7fff15f21b90..a3900baff8a1 100644 --- a/src/transformers/tokenization_utils_tokenizers.py +++ b/src/transformers/tokenization_utils_tokenizers.py @@ -751,11 +751,7 @@ def _contains_bytelevel(component: Any) -> bool: # without a ByteLevel normalizer. In this setup, raw unicode added tokens can decode incorrectly # (e.g. U+010D -> '\r'). Encoding added token contents through the ByteLevel alphabet # preserves roundtrip behavior. - if ( - _contains_bytelevel(pre_tokenizer) - and _contains_bytelevel(decoder) - and not _contains_bytelevel(normalizer) - ): + if _contains_bytelevel(pre_tokenizer) and _contains_bytelevel(decoder) and not _contains_bytelevel(normalizer): encoded_tokens: list[str | AddedToken] = [] for token in new_tokens: if isinstance(token, AddedToken): From 03484906c22aa71286b861751d49f83df78012c4 Mon Sep 17 00:00:00 2001 From: ErenAta16 Date: Sat, 28 Mar 2026 00:36:26 +0300 Subject: [PATCH 8/8] refactor: clarify ByteLevel add_tokens path and stabilize GPT-2 regression test - Return early for special_tokens before optional ByteLevel vocabulary encoding. - Load GPT-2 via from_pretrained_id and document #45051 in the test docstring. 
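For reference, the user-visible behavior the series locks in, mirroring the
regression test (a sketch assuming a standard gpt2 checkpoint):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.add_tokens(["kuća"])
    ids = tokenizer.encode("kuća", add_special_tokens=False)
    # Before this series, decode could emit control characters here.
    assert tokenizer.decode(ids, skip_special_tokens=False) == "kuća"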
Made-with: Cursor --- src/transformers/tokenization_utils_tokenizers.py | 4 +--- tests/models/gpt2/test_tokenization_gpt2.py | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/tokenization_utils_tokenizers.py b/src/transformers/tokenization_utils_tokenizers.py index a3900baff8a1..b0d65b7c00c4 100644 --- a/src/transformers/tokenization_utils_tokenizers.py +++ b/src/transformers/tokenization_utils_tokenizers.py @@ -725,11 +725,9 @@ def _convert_id_to_token(self, index: int) -> str | None: return self._tokenizer.id_to_token(int(index)) def _add_tokens(self, new_tokens: list[str | AddedToken], special_tokens=False) -> int: - if not special_tokens: - new_tokens = self._maybe_encode_added_tokens_for_bytelevel(new_tokens) if special_tokens: return self._tokenizer.add_special_tokens(new_tokens) - + new_tokens = self._maybe_encode_added_tokens_for_bytelevel(new_tokens) return self._tokenizer.add_tokens(new_tokens) def _maybe_encode_added_tokens_for_bytelevel(self, new_tokens: list[str | AddedToken]) -> list[str | AddedToken]: diff --git a/tests/models/gpt2/test_tokenization_gpt2.py b/tests/models/gpt2/test_tokenization_gpt2.py index 42b5c1491e31..859aa8232851 100644 --- a/tests/models/gpt2/test_tokenization_gpt2.py +++ b/tests/models/gpt2/test_tokenization_gpt2.py @@ -85,7 +85,8 @@ def test_tokenization_tiktoken(self): ) def test_added_tokens_unicode_roundtrip_with_bytelevel(self): - tokenizer = AutoTokenizer.from_pretrained("gpt2") + """Regression (#45051): added vocabulary with Unicode must encode/decode cleanly for ByteLevel without a normalizer.""" + tokenizer = AutoTokenizer.from_pretrained(self.from_pretrained_id[0]) new_tokens = ["Začnimo", "kuća", "međa"] tokenizer.add_tokens(new_tokens)