
Fix KeyError when patching mistral regex #43376

Merged
vasqu merged 6 commits into huggingface:main from LeonardoEmili:bugfix_fix_mistral_regex on Mar 19, 2026

Conversation

@LeonardoEmili
Contributor

What does this PR do?

It seems the same fix_mistral_regex argument is provided multiple times; replacing get with pop removes the key from kwargs and avoids running into the KeyError, which fixes the issue.

  File "/mambaforge/envs/cf/lib/python3.10/site-packages/transformers/models/auto/processing_auto.py", line 400, in from_pretrained
    return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)
  File "/mambaforge/envs/cf/lib/python3.10/site-packages/transformers/processing_utils.py", line 1413, in from_pretrained
    args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, processor_dict, **kwargs)
  File "/mambaforge/envs/cf/lib/python3.10/site-packages/transformers/processing_utils.py", line 1525, in _get_arguments_from_pretrained
    tokenizer = cls._load_tokenizer_from_pretrained(
  File "/mambaforge/envs/cf/lib/python3.10/site-packages/transformers/processing_utils.py", line 1474, in _load_tokenizer_from_pretrained
    tokenizer = auto_processor_class.from_pretrained(
  File "/mambaforge/envs/cf/lib/python3.10/site-packages/transformers/models/auto/tokenization_auto.py", line 700, in from_pretrained
    return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
  File "/mambaforge/envs/cf/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 1757, in from_pretrained
    return cls._from_pretrained(
  File "/mambaforge/envs/cf/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2009, in _from_pretrained
    tokenizer = cls(*init_inputs, **init_kwargs)
  File "/mambaforge/envs/cf/lib/python3.10/site-packages/transformers/models/gemma/tokenization_gemma.py", line 99, in __init__
    super().__init__(
  File "/mambaforge/envs/cf/lib/python3.10/site-packages/transformers/tokenization_utils_tokenizers.py", line 373, in __init__
    self._tokenizer = self._patch_mistral_regex(
KeyError: 'fix_mistral_regex'
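
For context, a minimal standalone sketch (with hypothetical names, not the actual transformers code) of why get vs. pop matters here: get leaves the key in kwargs, so the same argument ends up being forwarded a second time through **kwargs.

# Hypothetical helper to illustrate the duplicate-keyword problem.
def _patch(tokenizer, fix_mistral_regex=None, **kwargs):
    return tokenizer

kwargs = {"fix_mistral_regex": True}

# Buggy pattern: get() leaves the key in kwargs, so it is passed both
# explicitly and via **kwargs -> "got multiple values for keyword argument".
# _patch(None, fix_mistral_regex=kwargs.get("fix_mistral_regex"), **kwargs)

# Fixed pattern: pop() removes the key from kwargs before forwarding.
_patch(None, fix_mistral_regex=kwargs.pop("fix_mistral_regex", None), **kwargs)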

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp @ArthurZucker @itazap

@LeonardoEmili LeonardoEmili changed the title from "Fix KeyError: 'fix_mistral_regex' when patching mistral regex" to "Fix KeyError when patching mistral regex" on Jan 20, 2026
Contributor

@vasqu vasqu left a comment

Do you have a reproducer? We need to add a test to avoid encountering this again

@LeonardoEmili
Contributor Author

LeonardoEmili commented Jan 21, 2026

Do you have a reproducer? We need to add a test to avoid encountering this again

Of course. Note that the stable releases are not affected; the issue seems to arise when switching to v5.0.0rc2 (I had to switch to the pre-release after facing this issue). Here's the reproducer:

conda create -n pr_43376 python=3.10 -y
conda activate pr_43376  # activate the env so the uv pip installs below target it
uv pip install torch --index-url https://download.pytorch.org/whl/cpu
uv pip install pillow protobuf  # gemmatranslate requirements
uv pip install git+https://github.com/huggingface/transformers.git@v5.0.0rc2

>>> from transformers import AutoProcessor
>>> model_id = "google/translategemma-4b-it"
>>> processor = AutoProcessor.from_pretrained(model_id, token='<HF_TOKEN>')  # suggests fix_mistral_regex=True as in https://huggingface.co/google/t5gemma-2-4b-4b/discussions/4
>>> processor = AutoProcessor.from_pretrained(model_id, token='<HF_TOKEN>', fix_mistral_regex=True)
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/miniconda3/envs/pr_43376/lib/python3.10/site-packages/transformers/models/auto/processing_auto.py", line 395, in from_pretrained
    return processor_class.from_pretrained(
  File "/opt/miniconda3/envs/pr_43376/lib/python3.10/site-packages/transformers/processing_utils.py", line 1413, in from_pretrained
    args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, processor_dict, **kwargs)
  File "/opt/miniconda3/envs/pr_43376/lib/python3.10/site-packages/transformers/processing_utils.py", line 1525, in _get_arguments_from_pretrained
    tokenizer = cls._load_tokenizer_from_pretrained(
  File "/opt/miniconda3/envs/pr_43376/lib/python3.10/site-packages/transformers/processing_utils.py", line 1474, in _load_tokenizer_from_pretrained
    tokenizer = auto_processor_class.from_pretrained(
  File "/opt/miniconda3/envs/pr_43376/lib/python3.10/site-packages/transformers/models/auto/tokenization_auto.py", line 700, in from_pretrained
    return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
  File "/opt/miniconda3/envs/pr_43376/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 1757, in from_pretrained
    return cls._from_pretrained(
  File "/opt/miniconda3/envs/pr_43376/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2009, in _from_pretrained
    tokenizer = cls(*init_inputs, **init_kwargs)
  File "/opt/miniconda3/envs/pr_43376/lib/python3.10/site-packages/transformers/models/gemma/tokenization_gemma.py", line 99, in __init__
    super().__init__(
  File "/opt/miniconda3/envs/pr_43376/lib/python3.10/site-packages/transformers/tokenization_utils_tokenizers.py", line 373, in __init__
    self._tokenizer = self._patch_mistral_regex(
TypeError: transformers.tokenization_utils_tokenizers.TokenizersBackend._patch_mistral_regex() got multiple values for keyword argument 'fix_mistral_regex'

As a further test, I tried the code on the main branch (uv pip install git+https://github.com/huggingface/transformers.git@5c773b8a84677192d4a52edc1e2c8823f9c1dcea) and it no longer crashes. The reason seems to be that the pre_tokenizer has been removed from this model, but the issue can still appear with other models (I don't have one on hand that fails), since the patch only runs when this condition holds:

if vocab_size > 100000 and getattr(self._tokenizer, "pre_tokenizer", None) is not None:

This issue should still be addressed for future transformers versions: whenever the pre_tokenizer exists, it would fail again, since the fix_mistral_regex key is provided twice (as a named argument and via kwargs); see the offending call below. Hope this helps, cheers!

self._tokenizer = self._patch_mistral_regex(
    self._tokenizer,
    self.init_kwargs.get("name_or_path", None),
    init_kwargs=self.init_kwargs,
    # get() leaves the key in kwargs, so it is forwarded again via **kwargs below
    fix_mistral_regex=kwargs.get("fix_mistral_regex"),
    **kwargs,
)
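
For reference, a minimal sketch of the pop-based fix (the None default is my assumption; the merged diff may differ):

self._tokenizer = self._patch_mistral_regex(
    self._tokenizer,
    self.init_kwargs.get("name_or_path", None),
    init_kwargs=self.init_kwargs,
    # pop() removes the key from kwargs, so it is only passed once
    fix_mistral_regex=kwargs.pop("fix_mistral_regex", None),
    **kwargs,
)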

@LeonardoEmili LeonardoEmili requested a review from vasqu January 22, 2026 11:42
Contributor

@vasqu vasqu left a comment

Thanks a lot for the details! The fix is obviously correct, but we definitely need a test, even if it is with a dummy tokenizer; I can move it to our internal testing repos as well.

@younesbelkada
Contributor

Hi @vasqu @LeonardoEmili

I can confirm I also encounter this on my end with the latest commit from transformers main; below is a simple reproducer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Ministral-3-3B-Instruct-2512", fix_mistral_regex=True)

I also confirm the changes in this PR fix the issue. I am happy to help finalize the PR if that makes sense 🙏

@vasqu
Contributor

vasqu commented Mar 12, 2026

Sure, let's get this in! Can we quickly add a regression test with the tokenizer? @LeonardoEmili @younesbelkada

I don't mind if it's this PR specifically or another one; let's just coordinate.
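
For illustration, a sketch of what such a regression test could look like, reusing the checkpoint from the reproducer above (the test actually merged may differ):

# Sketch of a possible regression test; the function name is an assumption.
from transformers import AutoTokenizer

def test_fix_mistral_regex_kwarg_does_not_crash():
    # Regression test for #43376: passing fix_mistral_regex explicitly used to
    # crash with KeyError/TypeError inside _patch_mistral_regex.
    tokenizer = AutoTokenizer.from_pretrained(
        "mistralai/Ministral-3-3B-Instruct-2512", fix_mistral_regex=True
    )
    assert tokenizer is not None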

@younesbelkada
Contributor

younesbelkada commented Mar 12, 2026

Sure, happy either way. @LeonardoEmili, if you don't plan to add a test on your PR, I can open a new one and add you as co-author. Let me know what works best!

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto

Contributor

@vasqu vasqu left a comment

I took the liberty of adding the tests and one last detail, since I want to merge this.

Thanks a lot to both of you @younesbelkada @LeonardoEmili (I probably would have forgotten/lost it otherwise 😅)

@vasqu vasqu enabled auto-merge March 19, 2026 07:18
@vasqu vasqu added this pull request to the merge queue Mar 19, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu vasqu removed this pull request from the merge queue due to a manual request Mar 19, 2026
@vasqu vasqu added this pull request to the merge queue Mar 19, 2026
@vasqu vasqu removed this pull request from the merge queue due to a manual request Mar 19, 2026
@vasqu vasqu added this pull request to the merge queue Mar 19, 2026
Merged via the queue into huggingface:main with commit cecacd3 Mar 19, 2026
28 checks passed