25 changes: 14 additions & 11 deletions tests/jax/test_fused_attn.py
@@ -19,8 +19,9 @@
from jax import value_and_grad, jit

from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine_jax import get_device_compute_capability

# Type annotations
Array = jnp.ndarray
@@ -146,8 +147,6 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
return cross_fused_attn(q, kv, mask, dropout_rng, **kwargs)


@pytest.mark.skipif(not is_fused_attn_kernel_available(),
reason="Fused attention kernel is not supported.")
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@@ -159,13 +158,14 @@ class TestSelfFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.self_fused_attn"""

@staticmethod
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, pad_ratio):
# Arbitrary seqlen backend has a limited spec for now
# No bias, only causal mask, and no variable seqlen
if (s > 512 or backend == Backend.Arbitrary) and (attn_bias_type != AttnBiasType.NO_BIAS or
attn_mask_type != AttnMaskType.CAUSAL_MASK
or pad_ratio != 0):
pytest.skip("Unsupported inputs combination.")
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
head_dim, pad_ratio):
if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0:
pytest.skip("Arbitrary seqlen backend hasn't support padded input.")

if not is_fused_attn_kernel_available(dtype, dtype, attn_bias_type, attn_mask_type,
dropout_probability, s, s, head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.")

def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio):
@@ -174,6 +174,9 @@ def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
head_dim=d,
pad_ratio=pad_ratio)
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
@@ -361,7 +364,7 @@ def grad_func(fused_attn_func, *args, **kwargs):
jnp.zeros_like(primitive_dbias[:, :, self.valid_len:, self.valid_len:]))


@pytest.mark.skipif(not is_fused_attn_kernel_available(),
@pytest.mark.skipif(get_device_compute_capability(0) not in [80, 90],
reason="Fused attention kernel is not supported.")
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
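For reference, the updated gating in this test file combines a compute-capability `skipif` with a per-case runtime availability check. Below is a minimal sketch of that pattern; the test name, parameter values, and the smoke-test body are illustrative only, not taken from the PR's test matrix.

```python
# Minimal sketch of the new skip pattern (illustrative values, hypothetical test).
import jax.numpy as jnp
import pytest

from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType,
                                               is_fused_attn_kernel_available)
from transformer_engine_jax import get_device_compute_capability


@pytest.mark.skipif(get_device_compute_capability(0) not in [80, 90],
                    reason="Fused attention kernel is not supported.")
def test_fused_attn_smoke():
    s, head_dim, dropout_probability = 512, 64, 0.0
    dtype = jnp.bfloat16
    # Skip configurations the current device/cuDNN build cannot serve.
    if not is_fused_attn_kernel_available(dtype, dtype, AttnBiasType.NO_BIAS,
                                          AttnMaskType.CAUSAL_MASK,
                                          dropout_probability, s, s, head_dim):
        pytest.skip("Unsupported inputs combination or device compute capability.")
    # ... build QKV tensors and compare fused vs. reference attention here ...
```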
9 changes: 7 additions & 2 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -44,7 +44,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
bool flag_m512 = false;
bool flag_arb = false;
if ((sm_arch_ >= 80)
if ((sm_arch_ == 80 || sm_arch_ == 90)
&& (head_dim == 64)
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
@@ -55,7 +55,12 @@
|| (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED))) {
flag_m512 = true;
}
if ((sm_arch_ >= 80)
if (
#if (CUDNN_VERSION >= 8903)
(sm_arch_ >= 80)
#else
(sm_arch_ == 80 || sm_arch_ == 90)
#endif
&& (max_seqlen_q == max_seqlen_kv)
&& ((head_dim == 64) || (head_dim == 128))
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
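The gate above now distinguishes the two FP16/BF16 paths: the max512 backend requires exactly sm80 or sm90, while the arbitrary-seqlen backend accepts any sm >= 80 when built against cuDNN >= 8.9.3 and otherwise keeps the sm80/sm90 restriction. Below is a rough Python-side illustration of the arch portion of that check, using the existing `get_device_compute_capability` binding; the cuDNN build-time condition has no Python binding shown here, so it is only noted in a comment.

```python
# Rough, hedged illustration of the sm-arch gate; the authoritative check
# lives in nvte_get_fused_attn_backend() on the C++ side.
from transformer_engine_jax import get_device_compute_capability

sm_arch = get_device_compute_capability(0)  # e.g. 80 on A100, 90 on H100

# F16 max512 backend: exactly Ampere (sm80) or Hopper (sm90).
max512_arch_ok = sm_arch in (80, 90)

# F16 arbitrary-seqlen backend: sm >= 80 when compiled with cuDNN >= 8.9.3;
# with an older cuDNN it is restricted to sm80/sm90 as well (CUDNN_VERSION check).
arbitrary_arch_ok = sm_arch >= 80

print(f"{max512_arch_ok=}, {arbitrary_arch_ok=}")
```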
34 changes: 30 additions & 4 deletions transformer_engine/jax/cpp_extensions.py
@@ -68,6 +68,33 @@ def jax_dtype_to_te_dtype(jax_dtype):
raise ValueError(f"Not support the {jax_dtype=}")


@dataclass(frozen=True)
class FusedAttnHelper:
"""
Helper for the fused attention backend
"""

q_type: jnp.dtype
kv_type: jnp.dtype
attn_bias_type: NVTE_Bias_Type
attn_mask_type: NVTE_Mask_Type
dropout_probability: float
max_seqlen_q: int
max_seqlen_kv: int
head_dim: int

def is_fused_attn_kernel_available(self):
"""Check if there is available fused attention kernel"""
return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend

def get_fused_attn_backend(self):
"""Get the fused attention kernel backend"""
return transformer_engine_jax.get_fused_attn_backend(
jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type),
NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED, self.attn_bias_type, self.attn_mask_type,
self.dropout_probability, self.max_seqlen_q, self.max_seqlen_kv, self.head_dim)


def merge_named_shape(base, new):
"""
merge named shape(ie, dict), no key conflict
@@ -2053,10 +2080,9 @@ def abstract(
output_shape = (batch, max_seqlen, num_head, head_dim)
output_dtype = qkv_dtype

backend = transformer_engine_jax.get_fused_attn_backend(
jax_dtype_to_te_dtype(qkv_dtype), jax_dtype_to_te_dtype(qkv_dtype),
NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED, attn_bias_type, attn_mask_type,
dropout_probability, max_seqlen, max_seqlen, head_dim)
backend = FusedAttnHelper(qkv_dtype, qkv_dtype, attn_bias_type, attn_mask_type,
dropout_probability, max_seqlen, max_seqlen,
head_dim).get_fused_attn_backend()

if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen)
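Since `FusedAttnHelper` is a frozen dataclass, one instance can both answer the availability question and report the concrete backend that `abstract()` above switches on. A hedged usage sketch follows; the dtypes, sequence lengths, and head_dim are illustrative.

```python
# Hedged sketch: querying the backend the abstract() rule relies on.
import jax.numpy as jnp

from transformer_engine_jax import (NVTE_Bias_Type, NVTE_Fused_Attn_Backend,
                                    NVTE_Mask_Type)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper

helper = FusedAttnHelper(q_type=jnp.float16, kv_type=jnp.float16,
                         attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
                         attn_mask_type=NVTE_Mask_Type.NVTE_CAUSAL_MASK,
                         dropout_probability=0.0,
                         max_seqlen_q=512, max_seqlen_kv=512, head_dim=64)

if not helper.is_fused_attn_kernel_available():
    print("no fused attention kernel for this configuration")
elif helper.get_fused_attn_backend() == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
    # For this backend, abstract() sizes softmax_aux as
    # (batch, num_head, max_seqlen, max_seqlen).
    print("max512 backend selected")
else:
    print("another fused attention backend selected")
```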
1 change: 0 additions & 1 deletion transformer_engine/jax/csrc/extensions.cpp
@@ -63,7 +63,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
m.def("is_fused_attn_kernel_available", &IsFusedAttnKernelAvailable);
m.def("get_fused_attn_backend", &GetFusedAttnBackend);

pybind11::enum_<DType>(m, "DType", pybind11::module_local())
11 changes: 1 addition & 10 deletions transformer_engine/jax/csrc/modules.cpp
@@ -18,6 +18,7 @@
#include <vector>

#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/fused_attn.h"
@@ -89,16 +90,6 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
bias_type, mask_type, dtype, is_training});
}

bool IsFusedAttnKernelAvailable() {
#if (CUDNN_VERSION >= 8901)
auto major = cudaDevicePropertiesManager::Instance().GetMajor();
// Fused attention requires at least Ampere
return major >= 8;
#else
return false;
#endif
}

void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
void *output) {
auto input_shape = std::vector<size_t>{rows, cols};
2 changes: 0 additions & 2 deletions transformer_engine/jax/csrc/modules.h
@@ -114,8 +114,6 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training);

bool IsFusedAttnKernelAvailable();

NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
49 changes: 27 additions & 22 deletions transformer_engine/jax/flax/transformer.py
@@ -414,7 +414,21 @@ def kv_init(key, shape, dtype):

return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

first_sharding_type, second_sharding_type = infer_sharding_type()
# TODO(rewang): make it configurable for pre_scale_bias
attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS

def canonicalize_attn_mask_type(attn_mask_type):
"""
Convert the string to AttnMaskType
"""
if attn_mask_type == 'causal':
return AttnMaskType.CAUSAL_MASK
if attn_mask_type == 'padding':
return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")

attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)

canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
@@ -427,11 +441,16 @@ def _check_seqlen(seqlen):
def _check_head_dim(head_dim):
return head_dim in [64, 128]

has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype,
attn_bias_type, attn_mask_type,
self.dropout_rate, q_seqlen,
kv_seqlen, self.head_dim)

use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
_check_seqlen(q_seqlen) and _check_seqlen(kv_seqlen) and \
_check_head_dim(self.head_dim) and \
is_fused_attn_kernel_available() and \
has_fused_attn_kernel and \
enable_fused_attn

if enable_fused_attn and not use_fused_attn:
@@ -454,12 +473,14 @@ def _check_head_dim(head_dim):
f"but got {kv_seqlen=}, "
if not _check_head_dim(self.head_dim):
reason += f"head_dim should be 64 or 128 but got {self.head_dim}, "
if not is_fused_attn_kernel_available():
reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "
if not has_fused_attn_kernel:
reason += "no fused attention kernel is available, "

warnings.warn(
f"Fused attention is not enabled, " \
f"{reason}fall back to unfused attention")
f"Fused attention is not enabled. Because " \
f"{reason}fall back to unfused attention.")

first_sharding_type, second_sharding_type = infer_sharding_type()

residual = inputs_q
if self.fuse_qkv:
@@ -629,22 +650,6 @@ def _check_head_dim(head_dim):
# ensure the old key never used
del dropout_rng

# TODO(rewang): make it configurable for pre_scale_bias
attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS

def canonicalize_attn_mask_type(attn_mask_type):
"""
Convert the string to AttnMaskType
"""
if attn_mask_type == 'causal':
return AttnMaskType.CAUSAL_MASK
if attn_mask_type == 'padding':
return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")

attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)

if inputs_q is inputs_kv:
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
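The relocated `canonicalize_attn_mask_type` helper now runs before the kernel-availability check, so the enum it returns can feed straight into `is_fused_attn_kernel_available`. Exercised standalone it behaves as sketched below; in the PR the helper is nested inside the attention module's call, so this copy is for illustration only.

```python
# Standalone sketch of the mask-type canonicalization shown above.
from transformer_engine.jax.fused_attn import AttnMaskType


def canonicalize_attn_mask_type(attn_mask_type):
    """Convert the string to AttnMaskType."""
    if attn_mask_type == 'causal':
        return AttnMaskType.CAUSAL_MASK
    if attn_mask_type == 'padding':
        return AttnMaskType.PADDING_MASK
    raise ValueError(f"Unsupported {attn_mask_type=}, "
                     "supported attn_mask_type = {'causal', 'padding'}")


assert canonicalize_attn_mask_type('causal') is AttnMaskType.CAUSAL_MASK
assert canonicalize_attn_mask_type('padding') is AttnMaskType.PADDING_MASK
try:
    canonicalize_attn_mask_type('full')  # unsupported string
except ValueError as err:
    print(err)
```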
19 changes: 11 additions & 8 deletions transformer_engine/jax/fused_attn.py
@@ -8,10 +8,10 @@
import jax
import jax.numpy as jnp

import transformer_engine_jax
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type

from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .sharding import get_fused_attn_sharding_meta
@@ -22,13 +22,6 @@
jax.config.update('experimental_xmap_spmd_lowering_manual', True)


def is_fused_attn_kernel_available():
"""
To check whether the fused attention kernel is available
"""
return transformer_engine_jax.is_fused_attn_kernel_available()


class AttnBiasType(Enum):
"""Attention Bias Type."""
NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
@@ -43,6 +36,16 @@ class AttnMaskType(Enum):
CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK


def is_fused_attn_kernel_available(q_type, kv_type, attn_bias_type, attn_mask_type,
dropout_probability, max_seqlen_q, max_seqlen_kv, head_dim):
"""
To check whether the fused attention kernel is available
"""
return FusedAttnHelper(q_type, kv_type, attn_bias_type.value, attn_mask_type.value,
dropout_probability, max_seqlen_q, max_seqlen_kv,
head_dim).is_fused_attn_kernel_available()


def self_fused_attn(qkv: jnp.ndarray,
bias: jnp.ndarray,
mask: jnp.ndarray,
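Finally, a short usage sketch of the new `is_fused_attn_kernel_available` signature, here with unequal query and key-value sequence lengths as a cross-attention caller would supply; all values are illustrative.

```python
# Usage sketch for the query-style availability check (illustrative values).
import jax.numpy as jnp

from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType,
                                               is_fused_attn_kernel_available)

ok = is_fused_attn_kernel_available(jnp.float16, jnp.float16,
                                    AttnBiasType.NO_BIAS,
                                    AttnMaskType.PADDING_MASK,
                                    0.1,    # dropout_probability
                                    128,    # max_seqlen_q
                                    512,    # max_seqlen_kv
                                    64)     # head_dim
print("fused attention available:", ok)
```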