12 changes: 10 additions & 2 deletions python/tvm/relax/backend/contrib/cutlass.py
@@ -247,11 +247,19 @@ def attention_patterns():
         ),
         (
             "cutlass.stacked_attention",
-            *make_stacked_attention_pattern(),
+            *make_stacked_attention_pattern(start_op="split"),
         ),
         (
             "cutlass.stacked_attention",
-            *make_stacked_attention_pattern(with_bias=True),
+            *make_stacked_attention_pattern(start_op="split", with_bias=True),
         ),
+        (
+            "cutlass.stacked_attention",
+            *make_stacked_attention_pattern(start_op="strided_slice"),
+        ),
+        (
+            "cutlass.stacked_attention",
+            *make_stacked_attention_pattern(start_op="strided_slice", with_bias=True),
+        ),
     ]

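The two `start_op` variants correspond to the two ways a fused QKV projection is usually taken apart in Relax: one `R.split` along the packed axis, or three `R.strided_slice` ops over it. As a rough illustration of what the new `strided_slice` entries are meant to catch, the sketch below (not part of this PR; the shapes, head sizes, and printed expectation are made up) builds a toy module in that form and hands it to the backend's existing `partition_for_cutlass` entry point:

# Illustrative sketch only: a stacked-QKV attention expressed with strided_slice,
# the form the new "cutlass.stacked_attention" pattern entries are intended to match.
# Shapes and head sizes (b=4, s=16, n=8, h=h_v=32) are invented for the example.
import tvm
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from tvm.script import ir_module
from tvm.script import relax as R


@ir_module
class StackedAttention:
    @R.function
    def main(qkv: R.Tensor((4, 16, 768), "float16")) -> R.Tensor((4, 16, 8, 32), "float16"):
        with R.dataflow():
            # q/k/v each occupy a contiguous chunk of the packed axis 2
            q = R.reshape(R.strided_slice(qkv, [2], [0], [256], [1]), [4, 16, 8, 32])
            k = R.reshape(R.strided_slice(qkv, [2], [256], [512], [1]), [4, 16, 8, 32])
            v = R.reshape(R.strided_slice(qkv, [2], [512], [768], [1]), [4, 16, 8, 32])
            out = R.nn.attention(q, k, v)
            R.output(out)
        return out


mod = partition_for_cutlass(StackedAttention)
# Expectation: the q/k/v slices, reshapes, and attention op end up in one fused
# function annotated as the "cutlass.stacked_attention" composite.
mod.show()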
23 changes: 18 additions & 5 deletions python/tvm/relax/backend/patterns.py
@@ -197,12 +197,15 @@ def make_attention_pattern(with_bias: bool = False):
     return out, annotations
 
 
-def make_stacked_attention_pattern(with_bias: bool = False):
+def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
     """
     Create pattern for fused multi head attention with stacked input.
 
     Parameters
     ----------
+    start_op: str
+        The starting op for pattern, i.e. `R.split` or `R.strided_slice`.
+
     with_bias: bool
         Whether or not to include bias addition
 
@@ -217,13 +220,23 @@ def make_stacked_attention_pattern(with_bias: bool = False):
         check function and codegen.
     """
     stacked_qkv = wildcard()
-    qkv_tuple = is_op("relax.split")(stacked_qkv)
+    if start_op == "split":
+        qkv_tuple = is_op("relax.split")(stacked_qkv)
+        query_raw = is_tuple_get_item(qkv_tuple, 0)
+        key_raw = is_tuple_get_item(qkv_tuple, 1)
+        value_raw = is_tuple_get_item(qkv_tuple, 2)
+    elif start_op == "strided_slice":
+        query_raw = is_op("relax.strided_slice")(stacked_qkv)
+        key_raw = is_op("relax.strided_slice")(stacked_qkv)
+        value_raw = is_op("relax.strided_slice")(stacked_qkv)
Review comment (member): we will also need to check the begin/end in strided_slice if value has a different sequence length. (A hedged sketch of such a check follows this file's diff.)
+    else:
+        raise NotImplementedError()
     query_reshape_list = wildcard()
     key_reshape_list = wildcard()
     value_reshape_list = wildcard()
-    query = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 0), query_reshape_list)
-    key = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 1), key_reshape_list)
-    value = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 2), value_reshape_list)
+    query = is_op("relax.reshape")(query_raw, query_reshape_list)
+    key = is_op("relax.reshape")(key_raw, key_reshape_list)
+    value = is_op("relax.reshape")(value_raw, value_reshape_list)
     annotations = {
         "stacked_qkv": stacked_qkv,
         "query_reshape_list": query_reshape_list,
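On the review note above about validating the slice bounds: one possible shape for such a check is sketched below. The helper name, its argument list, and the assumption that `begin`/`end` for the sliced axis can be read off the matched `relax.strided_slice` calls are mine, not code from this PR.

# Hypothetical sketch of the begin/end validation suggested in the review comment.
def check_stacked_qkv_slices(query_slice, key_slice, value_slice, n, h, h_v):
    """Return True if the three slices tile the packed axis as [n*h | n*h | n*h_v]."""

    def bounds(call):
        # Assumes begin/end for the single sliced axis are reachable on the matched
        # call; depending on the Relax revision they may live in attrs or in arguments.
        return int(call.attrs.begin[0]), int(call.attrs.end[0])

    q_begin, q_end = bounds(query_slice)
    k_begin, k_end = bounds(key_slice)
    v_begin, v_end = bounds(value_slice)

    # The three slices must be contiguous and sized n*h, n*h, n*h_v so that
    # query, key, and value share the same sequence length in the packed axis.
    return (
        q_begin == 0
        and q_end == k_begin == n * h
        and k_end == v_begin == 2 * n * h
        and v_end - v_begin == n * h_v
    )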
44 changes: 36 additions & 8 deletions tests/python/relax/test_codegen_cutlass.py
@@ -660,7 +660,7 @@ def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, bias_reshape, q
     return qkv, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
 
 
-def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale=None):
+def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None):
     dtype = str(qkv.dtype)
 
     from tvm.script.ir_builder import IRBuilder
@@ -676,10 +676,22 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale
             if bias is not None:
                 bias = R.arg("bias", R.Tensor(bias.shape, dtype))
             with R.dataflow() as frame:
-                qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
-                q = R.reshape(qkv_tuple[0], [b, s, n, h])
-                k = R.reshape(qkv_tuple[1], [b, s, n, h])
-                v = R.reshape(qkv_tuple[2], [b, s, n, h_v])
+                if op == "split":
+                    qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
+                    q = R.reshape(qkv_tuple[0], [b, s, n, h])
+                    k = R.reshape(qkv_tuple[1], [b, s, n, h])
+                    v = R.reshape(qkv_tuple[2], [b, s, n, h_v])
+                elif op == "strided_slice":
+                    q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), [b, s, n, h])
+                    k = R.reshape(
+                        R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), [b, s, n, h]
+                    )
+                    v = R.reshape(
+                        R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]),
+                        [b, s, n, h_v],
+                    )
+                else:
+                    raise NotImplementedError()
                 result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
                 R.output(result)
 
@@ -700,15 +712,31 @@ def stacked_attention_size(request):
     return request.param
 
 
-def test_stacked_attention_offload(stacked_attention_size):
+def test_stacked_attention_split_offload(stacked_attention_size):
     b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
     qkv, bias, ref = get_numpy_stacked_attention_ref(
         b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
     )
     if scale == "none":
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias)
+    else:
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias, scale)
+    if bias is None:
+        out = get_result_with_relax_cutlass_offload(mod, qkv)
+    else:
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+def test_stacked_attention_strided_slice_offload(stacked_attention_size):
+    b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
+    qkv, bias, ref = get_numpy_stacked_attention_ref(
+        b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
+    )
+    if scale == "none":
-        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias)
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias)
     else:
-        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias, scale)
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias, scale)
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, qkv)
     else:
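For readers skimming the test changes, the reference the new tests compare against is ordinary scaled dot-product attention computed in NumPy over the same packed-QKV slicing. The sketch below is not the body of `get_numpy_stacked_attention_ref`; it omits the bias and bias-reshape handling the real helper supports and assumes a default 1/sqrt(h) scale when none is given.

# Plain-NumPy view of what the stacked-attention tests check (a sketch, not the
# actual helper in test_codegen_cutlass.py): slice the packed QKV the same way
# the Relax module does, then compute standard scaled dot-product attention.
import numpy as np


def stacked_attention_ref(qkv, b, s, n, h, h_v, scale=None):
    q = qkv[:, :, : n * h].reshape(b, s, n, h).transpose(0, 2, 1, 3)
    k = qkv[:, :, n * h : 2 * n * h].reshape(b, s, n, h).transpose(0, 2, 1, 3)
    v = qkv[:, :, 2 * n * h :].reshape(b, s, n, h_v).transpose(0, 2, 1, 3)
    scale = scale if scale is not None else 1.0 / np.sqrt(h)
    score = np.einsum("bnsh,bnth->bnst", q, k) * scale  # (b, n, s, s)
    prob = np.exp(score - score.max(axis=-1, keepdims=True))
    prob /= prob.sum(axis=-1, keepdims=True)  # softmax over the key axis
    out = np.einsum("bnst,bnth->bnsh", prob, v)  # (b, n, s, h_v)
    return out.transpose(0, 2, 1, 3)  # (b, s, n, h_v), matching the ref above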