
JAX vmap issue with TE Attention #2685

@jmukherjee-nvidia


Describe the bug

Using JAX vmap results in an extra-dimension mismatch between segment ids and segment pos when the attention is initialized with the library defaults. This can be bypassed by explicitly passing segment_pos on the user side.

Steps/Code to reproduce bug
The repro below works with fix_vmap=True and fails with fix_vmap=False:

from typing import Any, Dict, Optional

import jax
import jax.numpy as jnp


def _flash_gpu_attention(
    query: jax.Array,
    key: jax.Array,
    value: jax.Array,
    q_segment_ids: Optional[jax.Array],
    kv_segment_ids: Optional[jax.Array],
    scale_factor: float,
    block_sizes: Dict[str, Any],
    fix_vmap: bool,
):
    # Imported lazily so these imports won't hurt the TPU case.
    from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
    from transformer_engine.jax.attention import SequenceDescriptor  # pytype: disable=import-error

    if q_segment_ids is None and kv_segment_ids is None:
        segment_ids = None
    else:
        batch_size = query.shape[0]
        q_seq = query.shape[2]
        kv_seq = key.shape[2]
        if q_segment_ids is None and kv_segment_ids is not None:
            q_segment_ids = jnp.ones((batch_size, q_seq), dtype=jnp.int32)
        if kv_segment_ids is None and q_segment_ids is not None:
            kv_segment_ids = jnp.ones((batch_size, kv_seq), dtype=jnp.int32)
        
        segment_pos = None
        if fix_vmap:
            def generate_default_pos(segment_ids):
                return jnp.where(segment_ids > 0, jnp.arange(segment_ids.shape[-1]), 0)
            
            q_segment_pos = generate_default_pos(q_segment_ids)
            kv_segment_pos = generate_default_pos(kv_segment_ids)
            segment_pos = (q_segment_pos, kv_segment_pos)
            
        segment_ids = SequenceDescriptor.from_segment_ids_and_pos(
            segment_ids=(q_segment_ids, kv_segment_ids),
            segment_pos=segment_pos,
        )

    head_dim = query.shape[-1]

    with nvidia_ctx():  # context manager from our codebase (not shown here)
        dpa_layer = DotProductAttention(
            head_dim=head_dim,
            num_attention_heads=query.shape[1],
            num_gqa_groups=key.shape[1],
            attn_mask_type="no_mask",  # 'no_mask', 'padding', 'causal', or 'padding_causal'
            attn_bias_type="no_bias",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
            attention_dropout=0,
            dropout_rng_name="aqt",
            dtype=jax.numpy.bfloat16,
            scale_factor=scale_factor,
            **block_sizes,
        )
        query = query.swapaxes(1, 2).astype(jax.numpy.bfloat16)
        key = key.swapaxes(1, 2).astype(jax.numpy.bfloat16)
        value = value.swapaxes(1, 2).astype(jax.numpy.bfloat16)

        output: jax.Array = dpa_layer.apply({}, query, key, value, segment_ids)
        output = output.swapaxes(1, 2)
        return output
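
For completeness, the failure is triggered when the wrapper above is mapped over an extra leading axis. A minimal sketch of that call pattern, assuming _flash_gpu_attention is in scope on a TE-enabled GPU host; the shapes, the empty block_sizes, and the size of the vmapped axis are illustrative assumptions, not the exact values from our codebase:

import functools

import jax
import jax.numpy as jnp

B, H, S, D = 2, 8, 128, 64  # assumed batch, heads, sequence length, head dim
V = 4                       # assumed size of the extra leading axis vmap maps over

q = jnp.zeros((V, B, H, S, D), dtype=jnp.bfloat16)
k = jnp.zeros((V, B, H, S, D), dtype=jnp.bfloat16)
v = jnp.zeros((V, B, H, S, D), dtype=jnp.bfloat16)
seg = jnp.ones((V, B, S), dtype=jnp.int32)

attn = functools.partial(
    _flash_gpu_attention,
    scale_factor=1.0 / D**0.5,
    block_sizes={},      # illustrative; our real config passes block sizes here
    fix_vmap=False,      # crashes; the same call with fix_vmap=True succeeds
)
out = jax.vmap(attn)(q, k, v, seg, seg)
# -> AssertionError in transformer_engine/jax/attention.py, get_seqlens_and_offsets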

Expected behavior

We don't expect a crash when segment positions are not explicitly supplied by the user.

Currently the code crashes at:
.venv/lib/python3.12/site-packages/transformer_engine/jax/attention.py", line 700, in get_seqlens_and_offsets
assert q_segment_ids.shape == q_segment_pos.shape
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
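
The explicit-segment_pos workaround passes because the positions are derived from the already-batched segment ids, so the two shapes agree by construction at the assert above. A standalone illustration of that invariant (shapes are illustrative assumptions):

import jax
import jax.numpy as jnp

def generate_default_pos(segment_ids):
    # Same helper as in the repro: positions count up along the sequence
    # axis and are zeroed where segment_ids marks padding.
    return jnp.where(segment_ids > 0, jnp.arange(segment_ids.shape[-1]), 0)

seg = jnp.ones((2, 128), dtype=jnp.int32)          # (batch, seq)
assert generate_default_pos(seg).shape == seg.shape

# Under vmap the helper sees per-example segment ids, so the derived
# positions keep the same batched shape as the segment ids they came from.
batched = jnp.ones((4, 2, 128), dtype=jnp.int32)   # extra leading vmap axis
assert jax.vmap(generate_default_pos)(batched).shape == batched.shape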

Environment overview

  • Environment location: Docker
  • Method of Transformer Engine install: pip, via the uv package manager (exact commands below)

Command to install using the uv manager:
uv sync --no-install-project

The project has these dependencies:

dependencies = [
    "jax[cuda12,k8s]==0.9.0.1",
    "jax-cuda12-plugin",
    "jax-cuda12-pjrt",
    "jaxlib",
    "flax==0.12.1",
    "triton; sys_platform == 'linux'",  # each dependency group uses a different version according to its other dependencies
    "transformer-engine-cu12>=2.11.0",
    "transformer-engine-jax>=2.11.0",
    "transformer-engine>=2.11.0",
]

  • Docker images/commands used:

Base docker: nvidia/cuda:12.8.1-devel-ubuntu22.04

With the following packages installed on top:

sudo apt-get update && sudo apt-get install -y -q --fix-missing \
    libnccl-dev \
    libcudnn9-dev-cuda-12 \
    libcudnn9-cuda-12

Environment details

jax: 0.9.0.1
jaxlib: 0.9.0.1
numpy: 2.4.2
python: 3.12.12 (main, Oct 14 2025, 21:25:31) [Clang 20.1.4 ]
device info: NVIDIA H100 80GB HBM3-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='h100-1-7b5b844db4-7zdng', release='6.6.97+', version='#1 SMP Fri Aug 22 11:53:37 UTC 2025', machine='x86_64')
XLA_PYTHON_CLIENT_ALLOCATOR=platform
XLA_PYTHON_CLIENT_MEM_FRACTION=0.75
XLA_PYTHON_CLIENT_PREALLOCATE=false

Device details

  • GPU model: H100, but applicable to other datacenter GPUs such as GB200

