
feat(models): Make MimiModel encoding padding-aware to ensure batch-to-individual consistency #43378

Open
harshaljanjani wants to merge 5 commits into huggingface:main from harshaljanjani:fix/mimi-batch-correctness

Conversation

@harshaljanjani
Contributor

@harshaljanjani harshaljanjani commented Jan 20, 2026

What does this PR do?

The following issues were identified and fixed in this PR:

MimiModel incorrectly processed batched inputs of different lengths because the _encode_frame method wasn't padding-aware, leading to significant output discrepancies (see the repro in the issue description).
→ Implemented a fix that slices inputs to their actual lengths during the encoding stage, pads the resulting embeddings, and provides a proper attention_mask to the MimiTransformerModel instance (a simplified sketch follows below).
→ Updated the _encode_frame and encode logic so that batched encoding matches single encoding within a 1e-5 threshold.

Fixes #43377.
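
For illustration, a minimal sketch of the slice-encode-pad approach described above (hypothetical, simplified helper; the actual changes live in modeling_mimi.py and the method names there may differ):

import torch
import torch.nn.functional as F

def encode_padding_aware(model, input_values, padding_mask):
    # input_values: (batch, 1, time); padding_mask: (batch, time), 1 at valid samples
    lengths = padding_mask.sum(dim=-1)
    # Encode each sample at its true length so conv padding sees the real signal end
    embs = [model.encoder(input_values[i : i + 1, :, : lengths[i]]) for i in range(len(lengths))]
    max_t = max(e.shape[-1] for e in embs)
    # Re-pad the per-sample embeddings back into a single batch tensor
    padded = torch.stack([F.pad(e[0], (0, max_t - e.shape[-1])) for e in embs])
    emb_lens = torch.tensor([e.shape[-1] for e in embs])
    # Boolean attention mask marking the valid embedding frames per sample
    attention_mask = torch.arange(max_t).unsqueeze(0) < emb_lens.unsqueeze(1)
    return padded, attention_mask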

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

harshaljanjani marked this pull request as ready for review January 20, 2026 18:45
github-actions bot requested a review from eustlb January 20, 2026 18:45
@Rocketknight1
Member

cc @eustlb and @ebezzam since I think this addresses your TODOs in the code!

harshaljanjani changed the title from "fix(models): Make MimiModel encoding padding-aware to ensure batch-to-individual consistency" to "feat(models): Make MimiModel encoding padding-aware to ensure batch-to-individual consistency" Jan 26, 2026
eustlb added the Audio label Jan 29, 2026
eustlb self-assigned this Jan 29, 2026
@harshaljanjani
Contributor Author

Following up on this PR, happy to make changes or add context if helpful!

@eustlb
Contributor

eustlb left a comment

Hey @harshaljanjani, thanks a lot for your PR.

Unfortunately, this is the approach we want to avoid, since it runs the model sequentially over the samples. The TODO there calls for handling batches natively.

For such usage, it's better to keep this for loop outside the model than to hide it from the user inside forward and present as batch inference something that actually isn't batched.
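
For context, the external-loop alternative would look roughly like this (hypothetical usage sketch; assumes the public MimiModel.encode API, which takes (batch, channels, time) input_values):

import torch

# Per-sample loop kept outside the model: each waveform is encoded at its
# true length, so no padding ever enters the conv stack.
with torch.no_grad():
    codes = [
        model.encode(torch.from_numpy(wav)[None, None, :]).audio_codes
        for wav in unpadded_waveforms  # numpy arrays at their natural lengths
    ]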

@harshaljanjani
Contributor Author

Thanks a ton for your time! If that's alright, I'll take some time to think through how we could support a truly batched approach for Mimi from the ground up and then get back to you with the code changes; really appreciate the direction :))


# TODO: @eustlb, let's make the encoder support padding_mask so that batched inputs are supported.
embeddings = self.encoder(input_values, padding_cache=padding_cache)
input_lengths = None
@harshaljanjani
Contributor Author

harshaljanjani commented Feb 10, 2026

Leaving some reasoning here: I took a deeper look and wrote the trace script below against main; the outputs diverge at transformer position 12 (2.60e-02) and downsample position 6 (0.33). The root cause is that the conv bias produces non-zero garbage at padded positions, which the strided convolutions at the boundary later mix into valid outputs (pytorch/audio#2242 documents the identical issue with wav2vec2).
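
A minimal, self-contained illustration of that failure mode (toy layers, not Mimi's actual convs): stacking two padded convolutions makes the bias written into the padded tail leak back into the last valid frame.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
conv1 = torch.nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=True)
conv2 = torch.nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=True)

x = torch.randn(1, 1, 20)
xp = F.pad(x, (0, 10))  # simulate right-padding a short sample in a batch

y = conv2(conv1(x))     # sample encoded alone
yp = conv2(conv1(xp))   # same sample encoded inside the padded batch

# conv1 writes non-zero values (bias + boundary leakage) into the padded tail;
# conv2's receptive field at the last valid frame then reads that garbage back in.
print(torch.allclose(y, yp[..., :20]))  # False: the boundary frame diverges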

The change adapts wav2vec2 patterns (see the sketch after this list):
→ Copied modeling_wav2vec2.py#L680-L683 (zero out padded tokens), adapted to run inside MimiEncoder.forward() after every layer using a time mask.
→ Copied modeling_wav2vec2.py#L1006-L1025 (compute output lengths after the convs), adapted to iterate over _mimiconv1d_layer_names and call _get_output_length.
→ Copied modeling_wav2vec2.py#L1027-L1045 (build an attention mask from lengths), adapted to use torch.arange(...) < encoder_output_lengths.
→ Also, noting this down since it's Mimi-specific: the downsample uses pad_mode="replicate", so the padded positions must contain the last valid embedding (gather + torch.where) to match individual encoding behavior.

All ops are now batched (no per-sample loops); happy to make further changes if needed :)
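
A rough sketch of the batched ops just described (standalone helpers for illustration only; in the PR the equivalents live inside MimiEncoder/MimiModel and reuse _get_output_length):

import torch

def lengths_to_mask(lengths, max_len):
    # (batch,) lengths -> (batch, max_len) boolean mask, True at valid frames
    return torch.arange(max_len, device=lengths.device).unsqueeze(0) < lengths.unsqueeze(1)

def zero_padded(hidden, lengths):
    # hidden: (batch, channels, time); zero out frames past each sample's length,
    # mirroring the wav2vec2 "zero padded tokens" pattern after every layer
    return hidden * lengths_to_mask(lengths, hidden.shape[-1]).unsqueeze(1)

def replicate_fill(hidden, lengths):
    # Replace padded frames with each sample's last valid frame, so the
    # replicate-pad downsample sees what it would see for a lone sample
    last = (lengths - 1).clamp(min=0).view(-1, 1, 1)
    last_frame = hidden.gather(-1, last.expand(-1, hidden.shape[1], 1))
    mask = lengths_to_mask(lengths, hidden.shape[-1]).unsqueeze(1)
    return torch.where(mask, hidden, last_frame)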

@harshaljanjani
Contributor Author

harshaljanjani commented Feb 10, 2026

Did some RCA to cleanly identify where the divergence actually stems from; I've left the analysis in the comment above along with the code here, and jotted down the reasoning for making it batched from the ground up. Hope you'll find this line of reasoning well-informed; if not, I'm happy to keep correcting it until it's production-ready and make further changes :)

TRACE CODE:

import torch
import numpy as np
from transformers import AutoModel, AutoFeatureExtractor

model_id = "kyutai/mimi"
fe = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
model.eval()
sr = fe.sampling_rate

# Two test tones of different lengths: 1.0 s @ 440 Hz and 0.5 s @ 554 Hz
a1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1.0, int(sr * 1.0))).astype(np.float32)
a2 = np.sin(2 * np.pi * 554 * np.linspace(0, 0.5, int(sr * 0.5))).astype(np.float32)

# Encode the short tone alone vs. inside a right-padded batch
inp2 = fe(a2, sampling_rate=sr, return_tensors="pt")
inp_b = fe([a1, a2], sampling_rate=sr, return_tensors="pt", padding=True)

with torch.no_grad():
    # Stage 1: conv encoder
    emb_s = model.encoder(inp2["input_values"])
    emb_b = model.encoder(inp_b["input_values"])

    # Valid sample counts per batch item, pushed through the conv stack
    pm = inp_b["padding_mask"]
    pm2d = pm.any(dim=1) if pm.dim() == 3 else pm
    lengths = pm2d.sum(dim=-1)
    enc_lens = lengths.clone()
    for ln in model.encoder._mimiconv1d_layer_names:
        enc_lens = model.encoder.get_submodule(ln)._get_output_length(enc_lens)
    vl = enc_lens[1].item()  # valid encoder frames for the short sample
    diff = torch.abs(emb_s[0, :, :vl] - emb_b[1, :, :vl]).max().item()
    print(f"enc valid pos max diff: {diff:.6e}")

    # Stage 2: transformer, with the padded frames masked out
    attn = torch.arange(emb_b.shape[-1]).unsqueeze(0) < enc_lens.unsqueeze(1)
    t_s = model.encoder_transformer(emb_s.transpose(1, 2), return_dict=True)
    t_b = model.encoder_transformer(emb_b.transpose(1, 2), attention_mask=attn, return_dict=True)
    t_so, t_bo = t_s["last_hidden_state"].transpose(1, 2), t_b["last_hidden_state"].transpose(1, 2)
    for t in range(vl):
        d = torch.abs(t_so[0, :, t] - t_bo[1, :, t]).max().item()
        tag = " ← DIVERGES" if d > 1e-4 else ""
        print(f"→ transformer pos {t}: {d:.6e}{tag}")

    # Stage 3: strided downsample
    ds_lens = model.downsample._get_output_length(enc_lens)
    ds_s, ds_b = model.downsample(t_so), model.downsample(t_bo)
    dl = ds_lens[1].item()
    for t in range(dl):
        d = torch.abs(ds_s[0, :, t] - ds_b[1, :, t]).max().item()
        tag = " ← DIVERGES" if d > 1e-4 else ""
        print(f"→ downsample pos {t}: {d:.6e}{tag}")

    # Stage 4: quantizer codes, shaped (codebooks, batch, frames)
    c_s, c_b = model.quantizer.encode(ds_s), model.quantizer.encode(ds_b)
    mis = (c_s[:, 0, :dl] != c_b[:, 1, :dl]).any(dim=0).sum().item()
    print(f"codebook mismatches: {mis}/{dl} positions")

    # Receptive field of the last valid downsample frame vs. the valid encoder range
    k = model.downsample.kernel_size.item()
    s = model.downsample.stride.item()
    p = model.downsample.padding_total.item()
    last = dl - 1
    max_pos = last * s - p + k - 1
    print(f"downsample last valid pos accesses encoder pos up to {max_pos}, valid range ends at {vl - 1}")
[screenshot of the trace output]
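
(For that last check: a strided conv's output frame t reads input positions t*s - p through t*s - p + k - 1, so the final valid downsample frame reaches past the last valid encoder frame whenever max_pos > vl - 1. With pad_mode="replicate", those out-of-range positions must hold the last valid embedding rather than padding garbage for batched and single encoding to agree.)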

@harshaljanjani
Contributor Author

Following up on this PR, happy to make changes or add context if helpful!

@harshaljanjani
Contributor Author

@eustlb Saw merge conflicts today and resolved them; I double-checked that there are no regressions or additional failing tests, the repro is intact, and the reasoning still stands; ready for review :)

@harshaljanjani
Contributor Author

@eustlb Just a gentle ping :)

@github-actions
Contributor

github-actions Bot commented Mar 5, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: mimi

@harshaljanjani
Contributor Author

@eustlb Just a gentle ping; I got a notification that the issue was unfortunately marked as stale, but the PR is ready for review :)
I've left replies here and here on how I tried to resolve your comment :)

@eustlb
Contributor

eustlb commented Mar 9, 2026

Hey @harshaljanjani, thanks for the work and for iterating on this!
And thanks for your patience. This needs a bit more attention than just reviewing this specific PR, as I’d prefer not to merge it as a hotfix but rather standardize the approach. I’ll review it ASAP, and whether it lands through this PR or another one, I’ll make sure to include you as a co-author.

@harshaljanjani
Contributor Author

@eustlb Thank you so much for the update; I completely understand, and appreciate you looping me in 🤗❤️
I'll keep an eye out for whichever direction you choose to take!



Successfully merging this pull request may close these issues.

[BUG] MIMI Encoder produces different outputs for batched vs single inputs due to missing padding mask support