feat(models): Make MimiModel encoding padding-aware to ensure batch-to-individual consistency #43378
harshaljanjani wants to merge 5 commits into huggingface:main
Conversation
Following up on this PR, happy to make changes or add context if helpful!
eustlb left a comment
Hey @harshaljanjani, thanks a lot for your PR.
Unfortunately, this is the approach we want to avoid, since it runs the model sequentially over the batch. The TODO there asks for handling this in a batched manner.
For such a usage, it's better to keep this for loop outside the model than to hide it from the user inside the forward and pretend it's batched inference when it actually isn't.
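(For context, a minimal sketch of what that caller-side loop could look like — not code from this thread; the `encode` call and `audio_codes` attribute are assumed from the public Mimi API.)

```python
# Hypothetical caller-side loop (illustration only): each clip is encoded at its true
# length, so no padding ever reaches the encoder and no mask handling is needed.
import torch
from transformers import AutoModel, AutoFeatureExtractor

model_id = "kyutai/mimi"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).eval()

def encode_per_sample(raw_audios):
    all_codes = []
    for audio in raw_audios:
        inputs = feature_extractor(
            audio, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt"
        )
        with torch.no_grad():
            codes = model.encode(inputs["input_values"]).audio_codes
        all_codes.append(codes[0])  # (num_quantizers, frames); frames differ per clip
    return all_codes
```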
Thanks a ton for your time! I'll take some time to think through how we could support a truly batched approach from the ground up for Mimi and then get back to you with the code changes, if that's alright; really appreciate the direction :))
# TODO: @eustlb, let's make the encoder support padding_mask so that batched inputs are supported.
embeddings = self.encoder(input_values, padding_cache=padding_cache)
input_lengths = None
Leaving some reasoning here: I took a deeper look and wrote the trace script below on main; the outputs diverge at transformer position 12 (2.60e-02) and at downsample position 6 (0.33). The root cause is that the conv bias produces non-zero garbage at padded positions, and the strided convolutions at the boundary later mix it into valid outputs (pytorch/audio#2242 documents the identical issue with wav2vec2).
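(Toy illustration of the bias point — not Mimi's actual layers: a biased Conv1d applied to an all-zero padded tail already yields non-zero activations, which is the "garbage" that strided convs straddling the valid/padded boundary can then fold into valid frames.)

```python
# Assumption-laden toy, not Mimi's encoder: a Conv1d with bias applied to a zeroed tail.
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=3, bias=True)

padded_tail = torch.zeros(1, 1, 10)  # stand-in for the zeros appended by batch padding
out = conv(padded_tail)
print(out.abs().max())  # > 0: the bias alone fills padded frames with non-zero values,
                        # which downstream strided convs whose windows cross the boundary
                        # can mix into otherwise-valid outputs
```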
The change adapts wav2vec2 patterns (see the sketch after this list for the general shape):
→ Copied modeling_wav2vec2.py#L680-L683 (zero out padded tokens), adapted to run inside MimiEncoder.forward() after every layer using a time mask.
→ Copied modeling_wav2vec2.py#L1006-L1025 (compute output lengths after the convs), adapted to iterate over _mimiconv1d_layer_names and call _get_output_length.
→ Copied modeling_wav2vec2.py#L1027-L1045 (build the attention mask from lengths), adapted to use torch.arange(...) < encoder_output_lengths.
→ Also, noting this since it's Mimi-specific: the downsample uses pad_mode="replicate", so padded positions should hold the last valid embedding (gather + torch.where) to match the individual-encoding behavior.
All ops are now batched (no per-sample loops); happy to make further changes if needed :)
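(Rough sketch of the batched handling described in the list above — helper names, shapes, and the replicate-fill are illustrative assumptions, not the actual diff.)

```python
# Illustrative helpers (names/shapes assumed): propagate valid lengths through the convs,
# build a mask from them, zero padded frames after each layer, and replicate-fill padded
# frames before the "replicate"-padded downsample so they look like an individually
# encoded clip.
import torch

def conv_output_length(length, kernel_size, stride, padding=0, dilation=1):
    # standard Conv1d output-length formula
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def lengths_to_mask(lengths, max_len):
    # (B,) valid lengths -> (B, T) boolean mask, True at valid positions
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]

def zero_padded(hidden_states, lengths):
    # hidden_states: (B, C, T); zero everything past each sample's valid length
    mask = lengths_to_mask(lengths, hidden_states.shape[-1])
    return hidden_states * mask[:, None, :].to(hidden_states.dtype)

def replicate_last_valid(hidden_states, lengths):
    # fill padded frames with each sample's last valid frame (gather + torch.where)
    bsz, channels, seq_len = hidden_states.shape
    mask = lengths_to_mask(lengths, seq_len)
    last_idx = (lengths - 1).clamp(min=0).view(bsz, 1, 1).expand(bsz, channels, 1)
    last_valid = hidden_states.gather(2, last_idx)  # (B, C, 1)
    return torch.where(mask[:, None, :], hidden_states, last_valid)
```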
Did some RCA to cleanly identify where the divergence actually stems from; I've left the analysis in the comment above along with the code here, and jotted down the reasoning for making it batched from the ground up. Hope you'll find this line of reasoning well-informed; if not, I'm happy to correct it until it's production-ready and make further changes :)

TRACE CODE:

import torch
import numpy as np
from transformers import AutoModel, AutoFeatureExtractor
model_id = "kyutai/mimi"
fe = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
model.eval()
sr = fe.sampling_rate
a1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1.0, int(sr * 1.0))).astype(np.float32)
a2 = np.sin(2 * np.pi * 554 * np.linspace(0, 0.5, int(sr * 0.5))).astype(np.float32)
inp2 = fe(a2, sampling_rate=sr, return_tensors="pt")
inp_b = fe([a1, a2], sampling_rate=sr, return_tensors="pt", padding=True)
with torch.no_grad():
    emb_s = model.encoder(inp2["input_values"])
    emb_b = model.encoder(inp_b["input_values"])
pm = inp_b["padding_mask"]
pm2d = pm.any(dim=1) if pm.dim() == 3 else pm
lengths = pm2d.sum(dim=-1)
enc_lens = lengths.clone()
# propagate valid input lengths through every conv layer to get encoder-frame lengths
for ln in model.encoder._mimiconv1d_layer_names:
    enc_lens = model.encoder.get_submodule(ln)._get_output_length(enc_lens)
vl = enc_lens[1].item()
diff = torch.abs(emb_s[0, :, :vl] - emb_b[1, :, :vl]).max().item()
print(f"enc valid pos max diff: {diff:.6e}")
attn = torch.arange(emb_b.shape[-1]).unsqueeze(0) < enc_lens.unsqueeze(1)
t_s = model.encoder_transformer(emb_s.transpose(1, 2), return_dict=True)
t_b = model.encoder_transformer(emb_b.transpose(1, 2), attention_mask=attn, return_dict=True)
t_so, t_bo = t_s["last_hidden_state"].transpose(1, 2), t_b["last_hidden_state"].transpose(1, 2)
for t in range(vl):
    d = torch.abs(t_so[0, :, t] - t_bo[1, :, t]).max().item()
    tag = " ← DIVERGES" if d > 1e-4 else ""
    print(f"→ transformer pos {t}: {d:.6e}{tag}")
ds_lens = model.downsample._get_output_length(enc_lens)
ds_s, ds_b = model.downsample(t_so), model.downsample(t_bo)
dl = ds_lens[1].item()
for t in range(dl):
    d = torch.abs(ds_s[0, :, t] - ds_b[1, :, t]).max().item()
    tag = " ← DIVERGES" if d > 1e-4 else ""
    print(f"→ downsample pos {t}: {d:.6e}{tag}")
c_s, c_b = model.quantizer.encode(ds_s), model.quantizer.encode(ds_b)
mis = (c_s[:, 0, :dl] != c_b[:, 1, :dl]).any(dim=0).sum().item()
print(f"codebook mismatches: {mis}/{dl} positions")
k = model.downsample.kernel_size.item()
s = model.downsample.stride.item()
p = model.downsample.padding_total.item()
last = dl - 1
max_pos = last * s - p + k - 1
print(f"downsample last valid pos accesses encoder pos up to {max_pos}, valid range ends at {vl - 1}")
Following up on this PR, happy to make changes or add context if helpful!
@eustlb Saw merge conflicts today and resolved them; I double-checked that there are no regressions or additional failing tests, the repro is intact, and the reasoning still stands. Ready for review :)
@eustlb Just a gentle ping :)
[For maintainers] Suggested jobs to run (before merge): run-slow: mimi
@eustlb Just a gentle ping; I got a notification that the issue was unfortunately marked as stale, but the PR is ready for review :)
Hey @harshaljanjani, thanks for the work and for iterating on this!
@eustlb Thank you so much for the update; I completely understand, and appreciate you looping me in 🤗❤️

What does this PR do?
The following issues were identified and fixed in this PR:
→ `MimiModel` incorrectly processed batched inputs with different lengths because the `_encode_frame` method wasn't padding-aware, leading to significant output discrepancies when tested (see the repro in the issue description).
→ Implemented a fix that slices inputs to their actual lengths during the encoding stage, pads the resulting embeddings, and provides a proper `attention_mask` to the `MimiTransformerModel` instance.
→ Updated the `_encode_frame` and `encode` logic to ensure that batched encoding matches single encoding within a `1e-5` threshold.

Fixes #43377.
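(A quick consistency check one could run against these changes, sketched below; it assumes `MimiModel.encode` accepts `padding_mask` and returns `audio_codes`, as in the issue repro, and compares the discrete codes rather than embeddings.)

```python
# Sketch of a batch-vs-individual consistency check (model id and API usage assumed
# from the issue repro; not part of the PR's test suite).
import numpy as np
import torch
from transformers import AutoModel, AutoFeatureExtractor

model_id = "kyutai/mimi"
fe = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).eval()
sr = fe.sampling_rate

long_clip = np.sin(2 * np.pi * 440 * np.linspace(0, 1.0, sr)).astype(np.float32)
short_clip = np.sin(2 * np.pi * 554 * np.linspace(0, 0.5, sr // 2)).astype(np.float32)

single = fe(short_clip, sampling_rate=sr, return_tensors="pt")
batch = fe([long_clip, short_clip], sampling_rate=sr, return_tensors="pt", padding=True)

with torch.no_grad():
    codes_single = model.encode(single["input_values"]).audio_codes
    codes_batch = model.encode(
        batch["input_values"], padding_mask=batch["padding_mask"]
    ).audio_codes

# the short clip's codes should agree up to its valid number of frames
n_frames = codes_single.shape[-1]
print(torch.equal(codes_single[0], codes_batch[1, :, :n_frames]))
```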
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.