Fix Seq2SeqLM ExecuTorch export: add encoder_attention_mask to decoder and use static encoder shapes #45523

Open
duyhv-qualgo wants to merge 2 commits into huggingface:main from duyhv-qualgo:fix/seq2seq-decoder-encoder-attention-mask

Conversation

duyhv-qualgo commented Apr 20, 2026

Problem

Two related bugs in src/transformers/integrations/executorch.py that break seq2seq (T5) ExecuTorch export:

Bug 1 — encoder_attention_mask not forwarded to decoder

Seq2SeqLMDecoderExportableModuleWithStaticCache.forward calls self.decoder(...) without passing encoder_attention_mask. For T5, the cross-attention module computes a relative position bias scaled by key_length. When no mask is provided, key_length equals the full padded sequence length (e.g. 512) instead of the real encoder output length. This causes a ~20× logit scale error and completely wrong greedy-decoding outputs.
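
A minimal sketch of the fixed forward, paraphrased from the change description; `__init__`, the lm_head, and the static-cache setup are elided, and attribute names like `self.static_cache` are assumptions, not the module's actual internals:

```python
import torch
from torch import Tensor

class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
    # Sketch only: __init__, lm_head, and static-cache setup are elided.
    def forward(
        self,
        decoder_input_ids: Tensor,
        encoder_hidden_states: Tensor,
        cache_position: Tensor,
        encoder_attention_mask: Tensor | None = None,  # new optional argument
    ) -> Tensor:
        outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            # Previously omitted: without it, T5 sizes its relative position
            # bias by the padded key_length (e.g. 512), not the real length.
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=self.static_cache,  # attribute name assumed
            use_cache=True,
            cache_position=cache_position,
        )
        return self.lm_head(outputs[0])
```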

Verified: exporting and running T5 without this fix produces semantically wrong translations. With the fix, ExecuTorch output matches HuggingFace model.generate() exactly (5/5 test cases, exact string match).

Bug 2 — Dynamic encoder dim conflicts with static KV cache

Seq2SeqLMExportableModule._export_decoder marks encoder_hidden_states dim-1 as a dynamic symbol (encoder_hidden_seq_length). With transformers 5.0, T5's cross-attention causal mask slices against the static KV-cache size:

```python
causal_mask = mask[:, :, :, : key_states.shape[-2]]  # static int
position_bias = position_bias + causal_mask          # symbolic dim → conflict
```

This raises RuntimeError: tensor a (1024) must match tensor b (s96) during torch.export.

Fix: remove the dynamic encoder dim; callers pad encoder inputs to max_cache_len (the static export shape), which is the correct assumption for static-shape ExecuTorch deployment.
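
As a sketch of what this looks like at the torch.export call site — `decoder_wrapper`, `max_cache_len`, and `hidden_size` are illustrative placeholders, not the exact code in `_export_decoder`:

```python
import torch

max_cache_len, hidden_size = 1024, 512  # illustrative static export sizes
decoder_input_ids = torch.zeros((1, 1), dtype=torch.long)
encoder_hidden_states = torch.zeros((1, max_cache_len, hidden_size))  # pre-padded
cache_position = torch.tensor([0])

# Before (removed): dim-1 of encoder_hidden_states was a symbolic dimension,
# which torch.export cannot reconcile with the static KV-cache size:
#   enc_seq = torch.export.Dim("encoder_hidden_seq_length", max=max_cache_len)
#   dynamic_shapes = {"encoder_hidden_states": {1: enc_seq}, ...}

# After: fully static shapes; callers pad encoder outputs to max_cache_len.
exported_decoder = torch.export.export(
    decoder_wrapper,  # placeholder: an instance of the wrapper sketched above
    (decoder_input_ids, encoder_hidden_states, cache_position),
    dynamic_shapes=None,
)
```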

Changes

  • Seq2SeqLMDecoderExportableModuleWithStaticCache.forward: add encoder_attention_mask: Tensor | None = None parameter and pass it to self.decoder(...).
  • Seq2SeqLMExportableModule._export_decoder: accept and pass encoder_attention_mask; remove dynamic encoder sequence length shape (use dynamic_shapes=None).
  • Seq2SeqLMExportableModule.export(): pass encoder_attention_mask (default: all-ones of shape [batch, max_cache_len]).
  • Seq2SeqLMExportableModule.generate(): build encoder_attention_mask from prompt_token_ids != 0 and pass it to each decoder step (see the sketch after this list).
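
A minimal sketch of that mask construction; the pad token id 0 matches the description above, while the small sizes are illustrative:

```python
import torch
import torch.nn.functional as F

max_cache_len = 8  # illustrative; the real value is the static export length
prompt_token_ids = torch.tensor([[37, 52, 891, 0, 0]])  # 0 = pad token id

# 1 for real prompt tokens, 0 for padding, as described above.
encoder_attention_mask = (prompt_token_ids != 0).long()

# Pad the mask (and, analogously, the encoder outputs) to the static length.
pad = max_cache_len - encoder_attention_mask.shape[1]
encoder_attention_mask = F.pad(encoder_attention_mask, (0, pad))
# -> tensor([[1, 1, 1, 0, 0, 0, 0, 0]]), passed to each decoder step
```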

Test plan

  • Verified on a custom Helsinki-NLP–style T5 seq2seq checkpoint: ExecuTorch fp32 export produces exact token-level match with model.generate() on 5 translation test cases after this fix
  • Existing tests/models/t5/test_modeling_t5.py @slow tests (if CI can run them)
  • tests/integrations/test_executorch.py — no seq2seq-specific tests exist yet; a follow-up can add one

Who can review?

cc @vasqu @Cyrilvallez

Fix Seq2SeqLM ExecuTorch export: add encoder_attention_mask to decoder and use static encoder shapes

Two related bugs in the seq2seq ExecuTorch export path:

1. `Seq2SeqLMDecoderExportableModuleWithStaticCache.forward` did not pass
   `encoder_attention_mask` to the decoder stack. For T5 (and any model
   using relative position bias scaled by key_length), omitting this mask
   causes the bias to be computed over the full padded sequence length
   rather than the real encoder length, producing ~20× logit scale errors
   and wrong greedy-decoding outputs.

2. `Seq2SeqLMExportableModule._export_decoder` marked `encoder_hidden_states`
   dim-1 as dynamic (`encoder_hidden_seq_length`). With transformers 5.0 the
   static KV-cache size is a compile-time constant; a symbolic encoder dim
   creates a shape conflict during `torch.export` for models like T5 that
   slice the cross-attention causal mask against the cache size.

Fix:
- Add optional `encoder_attention_mask` parameter to
  `Seq2SeqLMDecoderExportableModuleWithStaticCache.forward` and thread it
  through to `self.decoder(...)`.
- Remove the dynamic encoder dim in `_export_decoder`; callers are expected
  to pad encoder inputs to `max_cache_len` (the static export shape).
- Update `Seq2SeqLMExportableModule.export()` and `generate()` to build and
  pass the encoder attention mask automatically.
duyhv-qualgo force-pushed the fix/seq2seq-decoder-encoder-attention-mask branch from ad2e86c to 373f55c on April 20, 2026 at 05:27
vasqu (Contributor) commented Apr 23, 2026

@Cyrilvallez could you check it out
