From aed469b696829ff54a1c6819940dd80c9c49adae Mon Sep 17 00:00:00 2001
From: Yaxing Cai
Date: Mon, 17 Apr 2023 09:48:29 +0000
Subject: [PATCH 1/2] [Unity][BYOC] Fuse attention pattern with `strided_slice`

This PR expands the support for fused stacked attention patterns starting
with `strided_slice`. Initially, we only supported the fused stacked
attention pattern starting with `split` in #14608. But with the help of
#14583, we may have similar patterns starting with `strided_slice` as well.
---
 python/tvm/relax/backend/contrib/cutlass.py | 12 +++++-
 python/tvm/relax/backend/patterns.py        | 23 ++++++++---
 tests/python/relax/test_codegen_cutlass.py  | 45 +++++++++++++++++----
 3 files changed, 65 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 4515118f5889..06edd9febfd3 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -247,11 +247,19 @@ def attention_patterns():
         ),
         (
             "cutlass.stacked_attention",
-            *make_stacked_attention_pattern(),
+            *make_stacked_attention_pattern(start_op="split"),
         ),
         (
             "cutlass.stacked_attention",
-            *make_stacked_attention_pattern(with_bias=True),
+            *make_stacked_attention_pattern(start_op="split", with_bias=True),
+        ),
+        (
+            "cutlass.stacked_attention",
+            *make_stacked_attention_pattern(start_op="strided_slice"),
+        ),
+        (
+            "cutlass.stacked_attention",
+            *make_stacked_attention_pattern(start_op="strided_slice", with_bias=True),
         ),
     ]
 
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 9e34b0c96472..6197fe44ca70 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -197,12 +197,15 @@ def make_attention_pattern(with_bias: bool = False):
     return out, annotations
 
 
-def make_stacked_attention_pattern(with_bias: bool = False):
+def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
     """
     Create pattern for fused multi head attention with stacked input.
 
     Parameters
     ----------
+    start_op: str
+        The starting op of the pattern, either `R.split` or `R.strided_slice`.
+
     with_bias: bool
         Whether or not to include bias addition
 
@@ -217,13 +220,23 @@ def make_stacked_attention_pattern(with_bias: bool = False):
         check function and codegen.
""" stacked_qkv = wildcard() - qkv_tuple = is_op("relax.split")(stacked_qkv) + if start_op == "split": + qkv_tuple = is_op("relax.split")(stacked_qkv) + query_raw = is_tuple_get_item(qkv_tuple, 0) + key_raw = is_tuple_get_item(qkv_tuple, 1) + value_raw = is_tuple_get_item(qkv_tuple, 2) + elif start_op == "strided_slice": + query_raw = is_op("relax.strided_slice")(stacked_qkv) + key_raw = is_op("relax.strided_slice")(stacked_qkv) + value_raw = is_op("relax.strided_slice")(stacked_qkv) + else: + raise NotImplementedError() query_reshape_list = wildcard() key_reshape_list = wildcard() value_reshape_list = wildcard() - query = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 0), query_reshape_list) - key = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 1), key_reshape_list) - value = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 2), value_reshape_list) + query = is_op("relax.reshape")(query_raw, query_reshape_list) + key = is_op("relax.reshape")(key_raw, key_reshape_list) + value = is_op("relax.reshape")(value_raw, value_reshape_list) annotations = { "stacked_qkv": stacked_qkv, "query_reshape_list": query_reshape_list, diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 4309627bf0b9..84459f3092b7 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -660,7 +660,7 @@ def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, bias_reshape, q return qkv, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v -def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale=None): +def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None): dtype = str(qkv.dtype) from tvm.script.ir_builder import IRBuilder @@ -676,10 +676,23 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale if bias is not None: bias = R.arg("bias", R.Tensor(bias.shape, dtype)) with R.dataflow() as frame: - qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2) - q = R.reshape(qkv_tuple[0], [b, s, n, h]) - k = R.reshape(qkv_tuple[1], [b, s, n, h]) - v = R.reshape(qkv_tuple[2], [b, s, n, h_v]) + if op == "split": + qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2) + q = R.reshape(qkv_tuple[0], [b, s, n, h]) + k = R.reshape(qkv_tuple[1], [b, s, n, h]) + v = R.reshape(qkv_tuple[2], [b, s, n, h_v]) + elif op == "strided_slice": + qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2) + q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), [b, s, n, h]) + k = R.reshape( + R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), [b, s, n, h] + ) + v = R.reshape( + R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]), + [b, s, n, h_v], + ) + else: + raise NotImplementedError() result = R.emit(R.nn.attention(q, k, v, bias, qk_scale)) R.output(result) @@ -700,15 +713,31 @@ def stacked_attention_size(request): return request.param -def test_stacked_attention_offload(stacked_attention_size): +def test_stacked_attention_split_offload(stacked_attention_size): + b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size + qkv, bias, ref = get_numpy_stacked_attention_ref( + b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32" + ) + if scale == "none": + mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias) + else: + mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias, scale) + if bias is None: + out = get_result_with_relax_cutlass_offload(mod, qkv) + else: + 
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+def test_stacked_attention_strided_slice_offload(stacked_attention_size):
     b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
     qkv, bias, ref = get_numpy_stacked_attention_ref(
         b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
    )
     if scale == "none":
-        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias)
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias)
     else:
-        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias, scale)
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias, scale)
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, qkv)
     else:

From da5a27d38512c94933a96deda754b76b4f4233a6 Mon Sep 17 00:00:00 2001
From: Yaxing Cai
Date: Mon, 17 Apr 2023 23:37:38 +0000
Subject: [PATCH 2/2] Remove redundant `R.split` in strided_slice test branch

---
 tests/python/relax/test_codegen_cutlass.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 84459f3092b7..db8abf34c203 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -682,7 +682,6 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_s
                     k = R.reshape(qkv_tuple[1], [b, s, n, h])
                     v = R.reshape(qkv_tuple[2], [b, s, n, h_v])
                 elif op == "strided_slice":
-                    qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
                     q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), [b, s, n, h])
                     k = R.reshape(
                         R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), [b, s, n, h]
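
For reference, below is a minimal sketch (not part of the patches) of the kind of Relax module the new pattern is intended to match: the stacked QKV tensor is taken apart with `R.strided_slice` instead of `R.split`, reshaped per head, and fed to `R.nn.attention`, after which `partition_for_cutlass` from the module modified above should group the chain into a `cutlass.stacked_attention` composite. The module name, shapes, and dtype are illustrative assumptions, not values from the patches.

import tvm
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from tvm.script import relax as R


# Illustrative sizes (assumed): b=4, s=8, n=2 heads, h=h_v=16,
# so the stacked QKV last axis is n*h + n*h + n*h_v = 96.
@tvm.script.ir_module
class StackedAttentionSlice:
    @R.function
    def main(
        qkv: R.Tensor((4, 8, 96), "float16")
    ) -> R.Tensor((4, 8, 2, 16), "float16"):
        with R.dataflow():
            # Slice Q, K, V out of the stacked tensor along the last axis,
            # then reshape each slice to (b, s, n, h) for attention.
            q = R.reshape(R.strided_slice(qkv, [2], [0], [32], [1]), [4, 8, 2, 16])
            k = R.reshape(R.strided_slice(qkv, [2], [32], [64], [1]), [4, 8, 2, 16])
            v = R.reshape(R.strided_slice(qkv, [2], [64], [96], [1]), [4, 8, 2, 16])
            out = R.nn.attention(q, k, v)
            R.output(out)
        return out


# After partitioning, the strided_slice/reshape/attention chain should be offered to
# CUTLASS as a "cutlass.stacked_attention" composite (assuming the pattern's check
# function accepts the dtype and head sizes used here).
mod = partition_for_cutlass(StackedAttentionSlice)
mod.show()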