
Add Xiaomi MiMo-V2 #45144

Open
casinca wants to merge 85 commits into huggingface:main from casinca:xiaomi-mimo-v2

Conversation

@casinca
Contributor

@casinca casinca commented Mar 31, 2026

What does this PR do?

Hello, this PR adds the MiMo-V2-Flash model to the Transformers library.
Fixes #42954

MiMo-V2 is "The last of the OSS SOTAs" that isn't natively supported by the Transformers library (besides Kimi), so I hope we can make this work.

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Only pinging HF engineers who were involved in the PRs mentioned below:

Modular Sensei @vasqu, I took your comments from the other PR to find the best candidates to inherit from. See #45144 (comment) for some important points, though.

@ArthurZucker for gpt-oss similarity and text model

For reference, there were already 2 previous PRs (1 closed and 1 non-modular/abandoned):

 

Afaik, MiMo is a bit novel in the codebase as it's the only hybrid SWA model that combines both dual-theta RoPE and dual head dims + attn sinks (like gpt-oss). The most similar models have one or the other, but not both, so the modular file wasn't an easy-mode pass, pass, pass.

I added some verbose NOTES for context/choices and to make the review easier. I'll obviously remove these once things are settled.

Additional useful info

(copy-pasted from my own implementation's readme)

  • First layer is fully dense (GA+FFN, not MoE)
  • No shared experts in MoE
  • SWA and GA layers have different RoPE theta bases
  • SWA and GA layers have a different number of KV groups (GQA)
  • Value head dim is decoupled from QK head dim
  • Values are rescaled by 1/√2 (see the sketch after this list)
  • Attention sinks are only applied to SWA layers
  • Partial RoPE (rotating the first 33% of the head dim)
  • Triple dtype with 1 quant
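A minimal sketch, with placeholder values, of how three of these points look in code (layer-dependent RoPE theta, partial RoPE, and the value rescaling). This is illustrative only, not the actual modeling code:

```python
import torch

def rope_inv_freq(theta: float, head_dim: int, partial_ratio: float = 0.33) -> torch.Tensor:
    """Inverse frequencies for a partial RoPE covering only the first `partial_ratio` of the head dim."""
    rot_dim = int(head_dim * partial_ratio)
    rot_dim -= rot_dim % 2  # keep it even so dims pair up for rotation
    return 1.0 / (theta ** (torch.arange(0, rot_dim, 2, dtype=torch.float32) / rot_dim))

# SWA and GA layers use different theta bases (placeholder values, not the real config)
swa_inv_freq = rope_inv_freq(theta=10_000.0, head_dim=192)
ga_inv_freq = rope_inv_freq(theta=1_000_000.0, head_dim=192)

# Values are rescaled before attention; the released config exposes this as attention_value_scale
attention_value_scale = 1 / (2 ** 0.5)  # i.e. 1/√2
value_states = torch.randn(2, 2, 11, 128)  # (batch, kv_heads, seq, v_head_dim), decoupled from the 192 qk dim
value_states = value_states * attention_value_scale
```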

@casinca
Contributor Author

casinca commented Mar 31, 2026

Important points:


fixed in: #45441

MiMoV2FlashTopKRouter:
Atm it is not a modular implementation inheriting from the existing DeepseekV3TopkRouter in transformers, for two reasons:

  • The native DSV3 router uses masked_fill=0.0 instead of masked_fill=-inf when masking non-selected expert groups. The remote DSV3 repo was fixed (see https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/commit/e9b33add76883f293d6bf61f6bd89b497e80e335#d2h-632685) and remote MiMo also follows the fixed version, so I had to follow the correct behavior. If it makes sense to change this masking in the repo for DSV3, then I can try to inherit something.

  • The class atm combines the two split DeepseekV3TopkRouter and DeepseekV3MoE.route_tokens_to_experts() to serve as a drop-in replacement in MixtralSparseMoeBlock, and simply overrides self.gate. This follows the newer pattern used by MiniMax-M2 for fused experts etc.
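For context, here is a minimal sketch of that masking difference, with hypothetical shapes and a simplified group-limited routing (not the actual transformers or MiMo code):

```python
import torch

num_tokens, num_experts, n_group, topk_group = 4, 8, 4, 2
scores = torch.rand(num_tokens, num_experts)  # post-sigmoid routing scores

# pick the best expert groups per token (simplified: best expert per group as the group score)
group_scores = scores.view(num_tokens, n_group, -1).max(dim=-1).values
group_idx = group_scores.topk(topk_group, dim=-1).indices
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0).bool()
expert_mask = group_mask.repeat_interleave(num_experts // n_group, dim=1)

# native DSV3 behavior: non-selected experts keep a score of 0.0, which only stays
# correct as long as every live score is strictly positive
masked_old = scores.masked_fill(~expert_mask, 0.0)
# fixed behavior (remote DSV3 and MiMo): -inf guarantees non-selected experts
# can never survive the expert-level top-k
masked_new = scores.masked_fill(~expert_mask, float("-inf"))
top_k_experts = masked_new.topk(2, dim=-1).indices
```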

@vasqu
Contributor

vasqu commented Apr 2, 2026

Will try to take a look, likely next week after torch conference. Bad timing with holidays + conference sorry 😢 appreciate the work tho, just as a heads up that things are a bit delayed

@vasqu vasqu added the New model label Apr 2, 2026
@casinca
Contributor Author

casinca commented Apr 2, 2026

> Will try to take a look, likely next week after torch conference. Bad timing with holidays + conference sorry 😢 appreciate the work tho, just as a heads up that things are a bit delayed

Ah true, np, I understand. Thanks for letting me know.

Contributor

@vasqu vasqu left a comment


Some initial comments: I think the biggest point is re attention - I'd rather disable the attention types than support all these paths, it gets too messy.

@vasqu
Contributor

vasqu commented Apr 15, 2026

Feel free to ping me when it's ready for another review 🤗

@casinca
Contributor Author

casinca commented Apr 27, 2026

👋 @vasqu, I applied your suggestions and replied to your questions.

  • There was a sudden fix needed in the attention class: Xiaomi fixed it just a few days ago in their remote, concerning the rescaling of values - the hparam attention_value_scale was in the config but never used.
    I had recently lost remote:native logits parity, and after investigating, this turned out to be the cause. (I also left some comments as "unresolved" for more details.)
  • The integration test is copied from MiniMax. It's fine on my end, but strings/logits will have to be regenerated.

Thanks

Contributor

@vasqu vasqu left a comment


Looks super good, only some smaller details tbh. I do want to check FA3 (vllm kernel compatibility) before merge; pretty sure they could work on Hopper.

@auto_docstring
class MiMoV2FlashPreTrainedModel(DeepseekV3PreTrainedModel):
    _supports_sdpa = False  # disabling SDPA as it has no sink API atm (same as gpt-oss)
    _supports_flash_attn = False  # not compatible because of asymmetric qk/v head dims
Contributor


I have to recheck, but I believe that FA3 with sinks support (the vllm kernel) and FA4 might support it (less sure about FA4) - check out gpt-oss where we have a _compatible_xxx flag for this.

Let me test it before we merge, gonna need Hopper GPUs.

Contributor Author


Thanks, that was a good suggestion (I hadn't passed the extra FA3/FA4 flags, only default FA2, and was on an A100).
So I retried including FA3/FA4 on an H100.

FA2/FA3:

  • non-Hopper arch: doesn't work because of the asymmetric head dim
  • Hopper arch: the asymmetric head dim isn't a blocker anymore, but the combo "asymmetric head dim + sinks" is:
    RuntimeError: We don't support S aux with hdim != hdim_v

FA4: works on Hopper.

I'll re-enable all backends, even flex, so you can check on your end.
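As a rough illustration of what checking the backends could look like, a hedged snippet that loops over implementations with the tiny checkpoint mentioned below (it assumes this PR's branch is installed and is not part of the PR itself):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "casinca/tiny-mimo-v2-flash"  # tiny 5-layer config, see the model card linked below
tokenizer = AutoTokenizer.from_pretrained(model_id)

for impl in ["eager", "sdpa", "flash_attention_2", "flex_attention"]:
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation=impl, torch_dtype=torch.bfloat16, device_map="cuda"
        )
        inputs = tokenizer("hello", return_tensors="pt").to(model.device)
        with torch.no_grad():
            model(**inputs)
        print(impl, "ok")
    except Exception as err:  # the backends discussed above are expected to raise here
        print(impl, "failed:", err)
```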

class MiMoV2FlashPreTrainedModel(DeepseekV3PreTrainedModel):
    _supports_sdpa = False  # disabling SDPA as it has no sink API atm (same as gpt-oss)
    _supports_flash_attn = False  # not compatible because of asymmetric qk/v head dims
    _supports_flex_attn = False  # same as FA2 + head dim not being a power of 2
Contributor


Also unsure about this: did you try a forward with the actual model (but with fewer hidden layers, for example)? Not sure if it really has that limitation.

Contributor Author

@casinca casinca Apr 28, 2026


Yes, I ran the forward with the tiny-mimo-v2-flash in my hub (5 layers); you have the full config in the model card: https://huggingface.co/casinca/tiny-mimo-v2-flash

For flex, I was blocked by the combo "not a power of 2 + decoupled qk/v head dims" on torch 2.6, but on torch 2.10 only the decoupled qk/v head dims seem to be a blocker:

    raise self.create_no_valid_choices(name, "No choices exist for backend.")
torch._inductor.exc.InductorError: LoweringException: NoValidChoicesError: No choices to select. Provided reason: No choices exist for backend. please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 16, 11, 192], stride=[33792, 2112, 192, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 1, 11, 192], stride=[2112, 2112, 192, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 1, 11, 128], stride=[1408, 128, 128, 1]))
  ))
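A minimal repro of those shapes through compiled flex_attention would look roughly like this (illustrative only, not the actual test code):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Shapes taken from the traceback above: qk head_dim=192 (not a power of 2), v head_dim=128 (decoupled)
q = torch.randn(2, 16, 11, 192, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1, 11, 192, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1, 11, 128, device="cuda", dtype=torch.bfloat16)

compiled_flex = torch.compile(flex_attention)
# enable_gqa is needed since k/v have a single head vs 16 query heads;
# on the torch versions tested above this raises the LoweringException shown
out = compiled_flex(q, k, v, enable_gqa=True)
```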

@casinca
Contributor Author

casinca commented Apr 28, 2026

@vasqu I applied the remaining suggestions and re-enabled all backends so you can check on your side.
I left the comments about backends unresolved so you can see my new results with the Hopper GPU. Hope that helps.

Let me know and I'll re-update the attn backend flags accordingly. Thanks.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, mimo_v2_flash

@casinca
Contributor Author

casinca commented Apr 29, 2026

Nice, kudos to the Xiaomi team for open-sourcing MiMo-V2.5-Pro 🙏 (currently 1st OSS, tied with Kimi K2.6 on the AA Intelligence index).

I've delved into the V2.5-Pro architecture a bit; some differences with V2-Flash (not really architectural, mostly hparam tuning):

  • QKV linears are fused (see the sketch after this list)
  • SWA and full attn layers have an equal number of KV heads
  • SWA/full attn pattern is slightly changed (not a clean modulo, ie every 6th layer - more like every ~7th/8th)
  • the rest are hparam value changes like RoPE thetas or attention_value_scale, plus a bump in size
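For the fused QKV point, a hedged sketch of what a conversion script might do to split the fused weight back into separate projections (placeholder dims and names, not the real V2.5-Pro sizes):

```python
import torch

hidden_size, num_q_heads, num_kv_heads, qk_head_dim, v_head_dim = 1024, 16, 2, 192, 128
q_out = num_q_heads * qk_head_dim
k_out = num_kv_heads * qk_head_dim
v_out = num_kv_heads * v_head_dim

fused_qkv_weight = torch.randn(q_out + k_out + v_out, hidden_size)  # as stored in a fused checkpoint
q_weight, k_weight, v_weight = fused_qkv_weight.split([q_out, k_out, v_out], dim=0)
```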

Contributor

@vasqu vasqu left a comment


Close to merging :D nice job! Let's wrap up the last details and then I'll cross-check with our CI for final values.

pass


def eager_attention_forward_with_optional_sink(
Contributor


Suggested change
- def eager_attention_forward_with_optional_sink(
+ def eager_attention_forward(

let's keep the same name tbh, no need to add the optional sink stuff

Comment on lines +308 to +310
_supports_flash_attn = True # not compatible because of asymmetric qk/v head dims
_supports_flex_attn = True # same as FA2 + head dim not being a power of 2
_compatible_flash_implementations = ["kernels-community/vllm-flash-attn3", "flash_attention_4"]
Contributor


Suggested change
- _supports_flash_attn = True # not compatible because of asymmetric qk/v head dims
- _supports_flex_attn = True # same as FA2 + head dim not being a power of 2
- _compatible_flash_implementations = ["kernels-community/vllm-flash-attn3", "flash_attention_4"]
+ _supports_flash_attn = True # not compatible because of asymmetric qk/v head dims and/or sinks
+ _supports_flex_attn = True # asymmetric head dim + not being a power of 2
+ _compatible_flash_implementations = ["flash_attention_4"]

trusting you on this no worries, let's keep FA4 only then



@auto_docstring
class MiMoV2FlashModel(MixtralModel):
Contributor


Can you check out if we can inherit from LagunaModel? It's a new model, but I feel like it matches 1:1, with an extra optional lazy mask creation which doesn't hurt.

def test_model_rope_scaling_from_config(self, scaling_type):
    pass

def test_model_rope_scaling_frequencies(self):
Contributor


ig this is copied from gemma3? if so, can we add a comment to this test so that we have some reference


EXPECTED_LOGITS_LEFT_UNPADDED = Expectations(
    {
        ("cuda", 8): [
Contributor


Can you add a minor version so e.g. (8, 6)?
