Describe the bug
Using JAX `vmap` results in an extra-dimension mismatch between segment IDs and segment positions when the library is invoked with its defaults. This can be bypassed by explicitly passing `segment_pos` on the user side.
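To make the mismatch concrete, here is a minimal standalone sketch (not Transformer Engine code) of why a default built against the full batched array cannot match: inside a `vmap`-mapped function the mapped batch axis is invisible, so per-example segment IDs have one fewer dimension than anything constructed from the unbatched view.

```python
import jax
import jax.numpy as jnp

shapes_seen = []

def f(seg_ids):
    # Inside vmap the tracer is per-example: the mapped batch axis is gone.
    shapes_seen.append(seg_ids.shape)
    return seg_ids

batched_ids = jnp.ones((4, 8), dtype=jnp.int32)  # (batch, seq)
jax.vmap(f)(batched_ids)

# f saw shape (8,), while a default segment_pos built outside the mapped
# function against batched_ids would have shape (4, 8), which is the kind
# of disagreement that trips the shape assertion reported below.
```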
Steps/Code to reproduce bug
Using `fix_vmap=True` works; the call fails with `fix_vmap=False`:
```python
def _flash_gpu_attention(
    query: jax.Array,
    key: jax.Array,
    value: jax.Array,
    q_segment_ids: Optional[jax.Array],
    kv_segment_ids: Optional[jax.Array],
    scale_factor: float,
    block_sizes: Dict[str, Any],
    fix_vmap: bool,
):
  # Import locally so these imports won't hurt the TPU case.
  from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
  from transformer_engine.jax.attention import SequenceDescriptor  # pytype: disable=import-error

  if q_segment_ids is None and kv_segment_ids is None:
    segment_ids = None
  else:
    batch_size = query.shape[0]
    q_seq = query.shape[2]
    kv_seq = key.shape[2]
    if q_segment_ids is None and kv_segment_ids is not None:
      q_segment_ids = jnp.ones((batch_size, q_seq), dtype=jnp.int32)
    if kv_segment_ids is None and q_segment_ids is not None:
      kv_segment_ids = jnp.ones((batch_size, kv_seq), dtype=jnp.int32)
    segment_pos = None
    if fix_vmap:
      def generate_default_pos(segment_ids):
        return jnp.where(segment_ids > 0, jnp.arange(segment_ids.shape[-1]), 0)
      q_segment_pos = generate_default_pos(q_segment_ids)
      kv_segment_pos = generate_default_pos(kv_segment_ids)
      segment_pos = (q_segment_pos, kv_segment_pos)
    segment_ids = SequenceDescriptor.from_segment_ids_and_pos(
        segment_ids=(q_segment_ids, kv_segment_ids),
        segment_pos=segment_pos,
    )

  head_dim = query.shape[-1]
  with nvidia_ctx():
    dpa_layer = DotProductAttention(
        head_dim=head_dim,
        num_attention_heads=query.shape[1],
        num_gqa_groups=key.shape[1],
        attn_mask_type="no_mask",  # 'no_mask', 'padding', 'causal', or 'padding_causal'
        attn_bias_type="no_bias",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
        attention_dropout=0,
        dropout_rng_name="aqt",
        dtype=jax.numpy.bfloat16,
        scale_factor=scale_factor,
        **block_sizes,
    )
  query = query.swapaxes(1, 2).astype(jax.numpy.bfloat16)
  key = key.swapaxes(1, 2).astype(jax.numpy.bfloat16)
  value = value.swapaxes(1, 2).astype(jax.numpy.bfloat16)
  output: jax.Array = dpa_layer.apply({}, query, key, value, segment_ids)
  output = output.swapaxes(1, 2)
  return output
```
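For reference, the `generate_default_pos` helper in the `fix_vmap` workaround above produces running positions along the sequence axis, with padding tokens (segment ID 0) pinned to position 0. A small standalone check:

```python
import jax.numpy as jnp

def generate_default_pos(segment_ids):
    # Positions 0..seq-1 where a real segment ID is present, 0 on padding.
    return jnp.where(segment_ids > 0, jnp.arange(segment_ids.shape[-1]), 0)

seg = jnp.array([[1, 1, 1, 0, 0]], dtype=jnp.int32)  # one sequence, 2 padding slots
pos = generate_default_pos(seg)
# pos: [[0, 1, 2, 0, 0]]
```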
Expected behavior
We don't expect a crash when segment positions are not explicitly supplied by the user. Currently the code crashes at:

```
.venv/lib/python3.12/site-packages/transformer_engine/jax/attention.py", line 700, in get_seqlens_and_offsets
    assert q_segment_ids.shape == q_segment_pos.shape
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```
Environment overview (please complete the following information)
- Environment location: Docker
- Method of Transformer Engine install: pip, via the `uv` package manager:

```shell
uv sync --no-install-project
```
The project has these dependencies:

```toml
dependencies = [
    "jax[cuda12,k8s]==0.9.0.1",
    "jax-cuda12-plugin",
    "jax-cuda12-pjrt",
    "jaxlib",
    "flax==0.12.1",
    "triton; sys_platform == 'linux'",  # every group uses a different version according to its other dependencies
    "transformer-engine-cu12>=2.11.0",
    "transformer-engine-jax>=2.11.0",
    "transformer-engine>=2.11.0",
]
```
- If method of install is Docker, provide the `docker pull` and `docker run` commands used.

Base docker image: `nvidia/cuda:12.8.1-devel-ubuntu22.04`, with the following dependencies installed:

```shell
sudo apt-get update && sudo apt-get install -y -q --fix-missing \
    libnccl-dev \
    libcudnn9-dev-cuda-12 \
    libcudnn9-cuda-12
```
Environment details
If an NVIDIA docker image is used you don't need to specify these.

```
jax: 0.9.0.1
jaxlib: 0.9.0.1
numpy: 2.4.2
python: 3.12.12 (main, Oct 14 2025, 21:25:31) [Clang 20.1.4]
device info: NVIDIA H100 80GB HBM3-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='h100-1-7b5b844db4-7zdng', release='6.6.97+', version='#1 SMP Fri Aug 22 11:53:37 UTC 2025', machine='x86_64')
XLA_PYTHON_CLIENT_ALLOCATOR=platform
XLA_PYTHON_CLIENT_MEM_FRACTION=0.75
XLA_PYTHON_CLIENT_PREALLOCATE=false
```
Device details
- GPU model: H100, but applicable to other datacenter GPUs such as GB200