From 7bf000d670e51a42d90c7633d191d6b1c66e50bd Mon Sep 17 00:00:00 2001
From: Yaxing Cai
Date: Tue, 11 Apr 2023 13:17:43 -0700
Subject: [PATCH 1/3] [Unity][BYOC] Add fused patterns for stacked attention

In some models, the input Q, K and V for attention ops initially come from a
stacked tensor, which is then split and reshaped before calling the attention
op, i.e. stacked_qkv -> split -> reshape -> attention. We can skip the split
and reshape ops by manipulating the layout parameters in codegen. This PR adds
such fused patterns for stacked attention in BYOC, so that we can generate
code directly from stacked_qkv.
---
 .../contrib/cutlass/attention_operation.py   | 92 ++++++++++++-------
 python/tvm/contrib/cutlass/build.py          | 56 ++++++-----
 python/tvm/contrib/cutlass/gen_tensor_op.py  | 35 ++++---
 python/tvm/relax/backend/contrib/cutlass.py  |  9 ++
 python/tvm/relax/backend/patterns.py         | 26 +++++-
 src/relax/backend/contrib/cutlass/codegen.cc |  2 +
 tests/python/relax/test_codegen_cutlass.py   | 88 +++++++++++++++++-
 7 files changed, 236 insertions(+), 72 deletions(-)

diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index f7dee4e3b80a..8e34385b4441 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -19,45 +19,87 @@
 from .library import *


-def instantiate_attention_template(attrs, func_args):
+def instantiate_attention_template(attrs):
     """Return CUTLASS host code for fused multi head attention
     based on a template and the provided attribute map."""

     bias_template = {
         "B11S'": """
-    CHECK(${arg3}->ndim == 2); // B, 1, 1, S'
+    CHECK(${bias}->ndim == 2); // B, 1, 1, S'

-    p.attn_bias_ptr = reinterpret_cast(${arg3}->data);
+    p.attn_bias_ptr = reinterpret_cast(${bias}->data);
     p.bias_strideM = 0; // 0
     p.bias_strideH = 0; // 0
     p.bias_strideB = p.num_keys; // S'
 """,
         "B1SS'": """
-    CHECK(${arg3}->ndim == 3); // B, 1, S, S'
+    CHECK(${bias}->ndim == 3); // B, 1, S, S'

-    p.attn_bias_ptr = reinterpret_cast(${arg3}->data);
+    p.attn_bias_ptr = reinterpret_cast(${bias}->data);
     p.bias_strideM = p.num_keys; // S'
     p.bias_strideH = 0; // 0
     p.bias_strideB = p.bias_strideM * p.num_queries; // S' * S
 """,
         "BNSS'": """
-    CHECK(${arg3}->ndim == 4); // B, N, S, S'
+    CHECK(${bias}->ndim == 4); // B, N, S, S'

-    p.attn_bias_ptr = reinterpret_cast(${arg3}->data);
+    p.attn_bias_ptr = reinterpret_cast(${bias}->data);
     p.bias_strideM = p.num_keys; // S'
     p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S
     p.bias_strideB = p.bias_strideH * p.num_heads; // S' * S * N
 """,
     }

+    qkv_template = {
+        "default": """
+    p.query_ptr = reinterpret_cast(${query}->data);
+    p.key_ptr = reinterpret_cast(${key}->data);
+    p.value_ptr = reinterpret_cast(${value}->data);
+    CHECK(${query}->ndim == 4); // B, S, N, H
+    CHECK(${key}->ndim == 4); // B, S', N, H
+    CHECK(${value}->ndim == 4); // B, S', N, H'
+
+    // stride for N
+    p.q_strideH = p.head_dim; // H
+    p.k_strideH = p.head_dim; // H
+    p.v_strideH = p.head_dim_value; // H'
+
+    // stride for S
+    p.q_strideM = p.q_strideH * p.num_heads; // H * N
+    p.k_strideM = p.k_strideH * p.num_heads; // H * N
+    p.v_strideM = p.v_strideH * p.num_heads; // H' * N
+
+    // stride for B
+    p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
+    p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
+    p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'
+""",
+        "qkv_stacked": """
+    p.query_ptr = reinterpret_cast(${qkv}->data);
+    p.key_ptr = reinterpret_cast(${qkv}->data) 
+ p.head_dim * p.num_heads; + p.value_ptr = reinterpret_cast(${qkv}->data) + p.head_dim * p.num_heads * 2; + CHECK(${qkv}->ndim == 3); // B, S, NH + NH + NH' + + // stride for N + p.q_strideH = p.head_dim; // H + p.k_strideH = p.head_dim; // H + p.v_strideH = p.head_dim_value; // H' + + // stride for S + p.q_strideM = p.k_strideM = p.v_strideM = + p.q_strideH * p.num_heads + + p.k_strideH * p.num_heads + + p.v_strideH * p.num_heads; // H * N + H * N + H * N' + + // stride for B + p.q_strideB = p.k_strideB = p.v_strideB = + p.q_strideM * p.num_queries; // (H * N + H * N + H * N') * S +""", + } + template = """ using T = ${data_type}; - CHECK(${arg0}->ndim == 4); // B, S, N, H - CHECK(${arg1}->ndim == 4); // B, S', N, H - CHECK(${arg2}->ndim == 4); // B, S', N, H' - CHECK(out0->ndim == 4); // B, S, N, H' - using Attention = AttentionKernel; typename Attention::Params p; - - p.query_ptr = reinterpret_cast(${arg0}->data); - p.key_ptr = reinterpret_cast(${arg1}->data); - p.value_ptr = reinterpret_cast(${arg2}->data); p.logsumexp_ptr = nullptr; p.output_ptr = reinterpret_cast(out0->data); p.output_accum_ptr = nullptr; @@ -92,22 +130,11 @@ def instantiate_attention_template(attrs, func_args): p.num_keys = ${num_keys}; // S' p.scale = ${scale}; - // stride for N - p.q_strideH = p.head_dim; // H - p.k_strideH = p.head_dim; // H - p.v_strideH = p.head_dim_value; // H' - // stride for S - p.q_strideM = p.q_strideH * p.num_heads; // H * N - p.k_strideM = p.k_strideH * p.num_heads; // H * N - p.v_strideM = p.v_strideH * p.num_heads; // H' * N p.o_strideM = p.head_dim_value * p.num_heads; // H' * N + CHECK(out0->ndim == 4); // B, S, N, H' - // stride for B - p.q_strideB = p.q_strideM * p.num_queries; // H * N * S - p.k_strideB = p.k_strideM * p.num_keys; // H * N * S' - p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S' - + ${qkv_template} ${bias_template} constexpr auto kernel_fn = attention_kernel_batched_impl; @@ -126,9 +153,10 @@ def instantiate_attention_template(attrs, func_args): template = substitute_template( template, - {"bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" in attrs else ""}, + { + "qkv_template": qkv_template[attrs["qkv_layout"]], + "bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" in attrs else "", + }, ) - for i, arg in enumerate(func_args): - attrs["arg{}".format(i)] = arg return substitute_template(template, attrs) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 43494991a04c..e943aec6b1e1 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -546,12 +546,17 @@ def _extract_relax_function_signature(f): for i, arg in enumerate(f.params): sinfo = arg.struct_info - signature["arg%d_shape" % i] = get_const_tuple(sinfo.shape) - signature["arg%d_dtype" % i] = sinfo.dtype + if isinstance(sinfo, relax.TensorStructInfo): + signature["arg%d_shape" % i] = get_const_tuple(sinfo.shape) + signature["arg%d_dtype" % i] = sinfo.dtype + elif isinstance(sinfo, relax.ShapeStructInfo): + signature["arg%d_shape" % i] = get_const_tuple(sinfo.values) + else: + raise NotImplementedError() ret_sinfo = f.ret_struct_info if ret_sinfo.shape is not None: - signature["ret_shape"] = list(ret_sinfo.shape) + signature["ret_shape"] = get_const_tuple(ret_sinfo.shape) else: signature["ret_shape"] = None signature["ret_dtype"] = ret_sinfo.dtype @@ -779,34 +784,42 @@ def handle_attention(self, f, op_type): op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs else: raise 
ValueError(f"Cannot find call node for attention") - q_shape = signature["arg0_shape"] - k_shape = signature["arg1_shape"] - v_shape = signature["arg2_shape"] + arg = {} + + if "stacked_attention" in op_type: + arg["arg0_shape"] = signature["arg0_shape"] + arg["arg0_dtype"] = signature["arg0_dtype"] + arg["arg1_shape"] = q_shape = signature["arg1_shape"] + arg["arg2_shape"] = k_shape = signature["arg2_shape"] + arg["arg3_shape"] = v_shape = signature["arg3_shape"] + if "arg4_dtype" in signature: + arg["bias_dtype"] = signature["arg4_dtype"] + if "arg4_shape" in signature: + arg["bias_shape"] = signature["arg4_shape"] + qkv_layout = "qkv_stacked" + else: + arg["arg0_shape"] = q_shape = signature["arg0_shape"] + arg["arg1_shape"] = k_shape = signature["arg1_shape"] + arg["arg2_shape"] = v_shape = signature["arg2_shape"] + arg["arg0_dtype"] = signature["arg0_dtype"] + arg["arg1_dtype"] = signature["arg1_dtype"] + arg["arg2_dtype"] = signature["arg2_dtype"] + if "arg3_dtype" in signature: + arg["bias_dtype"] = signature["arg3_dtype"] + if "arg3_shape" in signature: + arg["bias_shape"] = signature["arg3_shape"] + qkv_layout = "default" out_shape = signature["ret_shape"] - q_dtype = signature["arg0_dtype"] - k_dtype = signature["arg1_dtype"] - v_dtype = signature["arg2_dtype"] out_dtype = signature["ret_dtype"] num_batches, num_queries, num_heads, head_dim = q_shape _, num_keys, _, _ = k_shape _, _, _, head_dim_value = v_shape scale = op_attrs.scale - bias = {} - if "arg3_dtype" in signature: - bias["arg3_dtype"] = signature["arg3_dtype"] - if "arg3_shape" in signature: - bias["arg3_shape"] = signature["arg3_shape"] return f.with_attrs( { "op_type": op_type, - "arg0_dtype": q_dtype, - "arg1_dtype": k_dtype, - "arg2_dtype": v_dtype, "ret_dtype": out_dtype, - "arg0_shape": q_shape, - "arg1_shape": k_shape, - "arg2_shape": v_shape, "ret_shape": out_shape, "num_batches": num_batches, "num_queries": num_queries, @@ -816,7 +829,8 @@ def handle_attention(self, f, op_type): "head_dim_value": head_dim_value, "scale": scale, "arch": self.options["sm"], - **bias, + "qkv_layout": qkv_layout, + **arg, } ) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 61c88c657f05..bb4d2243297c 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -500,12 +500,6 @@ def instantiate_template(func_name, annotations, func_args): if k in annotations: attrs[k] = annotations[k] - arg0_shape = annotations["arg0_shape"] - arg1_shape = annotations["arg1_shape"] - attrs["ElementInputA"] = DataTypeTag[dtype_map[annotations["arg0_dtype"]]] - attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations["arg1_dtype"]]] - attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]] - headers = [] if "relu" in func_name: @@ -649,12 +643,12 @@ def get_batch_on_arg(arg_name, arg_shape): if "conv2d_transpose" in func_name: headers.append("cutlass/conv/kernel/default_conv2d_dgrad.h") activation_shape = output_shape - output_shape = arg0_shape + output_shape = annotations["arg0_shape"] elif "backward" in func_name: headers.append("cutlass/conv/kernel/default_conv2d_wgrad.h") - activation_shape = arg1_shape + activation_shape = annotations["arg1_shape"] weight_shape = output_shape - output_shape = arg0_shape + output_shape = annotations["arg0_shape"] elif "residual" in func_name: headers.append("cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h") else: @@ -731,13 +725,26 @@ def get_batch_on_arg(arg_name, 
arg_shape): ), "Cutlass may generate nan occasionally when scale == 0.0" attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"]) attrs["kSupportsDropout"] = False - if len(func_args) > 3: + attrs["qkv_layout"] = annotations["qkv_layout"] + if attrs["qkv_layout"] == "default": + attrs["query"] = func_args[0] + attrs["key"] = func_args[1] + attrs["value"] = func_args[2] + if len(func_args) > 3: + attrs["bias"] = func_args[3] + elif attrs["qkv_layout"] == "qkv_stacked": + attrs["qkv"] = func_args[0] + if len(func_args) > 4: + attrs["bias"] = func_args[4] + else: + raise NotImplementedError() + if "bias" in attrs: attrs["kSupportsBias"] = True - if len(annotations["arg3_shape"]) == 4: + if len(annotations["bias_shape"]) == 4: attrs["bias_layout"] = "BNSS'" - elif len(annotations["arg3_shape"]) == 3: + elif len(annotations["bias_shape"]) == 3: attrs["bias_layout"] = "B1SS'" - elif len(annotations["arg3_shape"]) == 2: + elif len(annotations["bias_shape"]) == 2: attrs["bias_layout"] = "B11S'" else: raise NotImplementedError() @@ -745,7 +752,7 @@ def get_batch_on_arg(arg_name, arg_shape): # To support negative scale in current Cutlass implementation, # kSupportsBias should be set true, or there are nan's as result. attrs["kSupportsBias"] = attrs["scale"] < 0 - code = instantiate_attention_template(attrs, func_args) + code = instantiate_attention_template(attrs) return CodegenResult(code, headers) raise ValueError("Do not have a template for {}".format(func_name)) diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 856cd4d7871f..4515118f5889 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -29,6 +29,7 @@ make_fused_bias_activation_pattern, make_matmul_pattern, make_residual_block_pattern, + make_stacked_attention_pattern, ) @@ -244,6 +245,14 @@ def attention_patterns(): "cutlass.attention_bias", *make_attention_pattern(with_bias=True), ), + ( + "cutlass.stacked_attention", + *make_stacked_attention_pattern(), + ), + ( + "cutlass.stacked_attention", + *make_stacked_attention_pattern(with_bias=True), + ), ] diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index e27b91b3eaa6..a3f4ea8fa2bf 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -19,7 +19,7 @@ from typing import Dict, Mapping, Tuple, Union -from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard +from tvm.relax.dpl.pattern import DFPattern, is_const, is_op, is_tuple, is_tuple_get_item, wildcard def _with_bias_activation_pattern( @@ -190,3 +190,27 @@ def make_attention_pattern(with_bias: bool = False): out = is_op("relax.nn.attention")(query, key, value) return out, annotations + + +def make_stacked_attention_pattern(with_bias: bool = False): + stacked_qkv = wildcard() + qkv_tuple = is_op("relax.split")(stacked_qkv) + query_reshape_list = wildcard() + key_reshape_list = wildcard() + value_reshape_list = wildcard() + query = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 0), query_reshape_list) + key = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 1), key_reshape_list) + value = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 2), value_reshape_list) + annotations = { + "stacked_qkv": stacked_qkv, + "query_reshape_list": query_reshape_list, + "key_reshape_list": key_reshape_list, + "value_reshape_list": value_reshape_list, + } + if with_bias: + bias = wildcard() + annotations["bias"] = bias + out = 
is_op("relax.nn.attention_bias")(query, key, value, bias) + else: + out = is_op("relax.nn.attention")(query, key, value) + return out, annotations diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 8ef68baf6832..730d098510c0 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -59,6 +59,8 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, auto sinfo = GetStructInfo(arg); if (const auto* tensor_sinfo = sinfo.as()) { arg_types.emplace_back(backend::DType2String(tensor_sinfo->dtype)); + } else if (const auto* shape_sinfo = sinfo.as()) { + arg_types.emplace_back(backend::DType2String(shape_sinfo->values.value()[0]->dtype)); } else { LOG(FATAL) << "Unimplemented"; } diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index b9ba4f4dc9af..9288db3eb545 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -592,10 +592,10 @@ def attention_bias_size(request): return request.param -def test_attention_bias_offload(attention_bias_size, attention_dtype): +def test_attention_bias_offload(attention_bias_size): b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_bias_size q, k, v, bias, ref = get_numpy_attention_ref( - b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, "none", attention_dtype + b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, "none", "float32" ) mod = get_relax_attention_module(q, k, v, bias) @@ -620,10 +620,10 @@ def attention_scale(request): return request.param -def test_attention_scale_offload(attention_scale_size, attention_scale, attention_dtype): +def test_attention_scale_offload(attention_scale_size, attention_scale): b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_scale_size q, k, v, bias, ref = get_numpy_attention_ref( - b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, attention_scale, attention_dtype + b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, attention_scale, "float32" ) mod = get_relax_attention_module(q, k, v, bias, attention_scale) @@ -634,5 +634,85 @@ def test_attention_scale_offload(attention_scale_size, attention_scale, attentio tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) +@memoize("topi.tests.test_codegen_cutlass.test_stacked_attention_offload") +def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, bias_reshape, qk_scale, dtype): + qkv = np.random.randn(b, s, n * h + n * h + n * h_v).astype(dtype) + split_qkv = np.split(qkv, [n * h, n * h * 2], axis=2) + q = np.reshape(split_qkv[0], (b, s, n, h)) + k = np.reshape(split_qkv[1], (b, s, n, h)) + v = np.reshape(split_qkv[2], (b, s, n, h_v)) + qt = q.transpose(0, 2, 1, 3) # b, n, s, h + kt = k.transpose(0, 2, 3, 1) # b, n, h, s + if not qk_scale == "none": + score = qt @ kt * qk_scale # b, n, s, s + else: + score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s + if not bias_shape == "none": + bias = np.random.randn(*bias_shape).astype(dtype) + score = score + bias.reshape(*bias_reshape) # b, n, s, s + else: + bias = None + attn = tvm.topi.testing.softmax_python(score, -1) + vt = v.transpose(0, 2, 1, 3) # b, n, s, h_v + ref = attn @ vt # b, n, s, h_v + return qkv, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v + + +def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale=None): + dtype = str(qkv.dtype) + + from tvm.script.ir_builder import IRBuilder + from tvm.script.ir_builder import relax as relax_builder, tir as 
T + + if qk_scale is not None: + qk_scale = T.FloatImm("float32", qk_scale) + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype)) + if bias is not None: + bias = R.arg("bias", R.Tensor(bias.shape, dtype)) + with R.dataflow() as frame: + qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2) + q = R.reshape(qkv_tuple[0], [b, s, n, h]) + k = R.reshape(qkv_tuple[1], [b, s, n, h]) + v = R.reshape(qkv_tuple[2], [b, s, n, h_v]) + result = R.emit(R.nn.attention(q, k, v, bias, qk_scale)) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +@pytest.fixture( + params=[ + # B, S, N, H, bias_shape, bias_reshape, scale + (4, 8, 32, (64, 32), "none", "none", "none"), + (4, 8, 32, (64, 32), (4, 32, 8, 8), (4, 32, 8, 8), 0.5), + ] +) +def stacked_attention_size(request): + return request.param + + +def test_stacked_attention_offload(stacked_attention_size): + b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size + qkv, bias, ref = get_numpy_stacked_attention_ref( + b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32" + ) + if scale == "none": + mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias) + else: + mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias, scale) + if bias is None: + out = get_result_with_relax_cutlass_offload(mod, qkv) + else: + out = get_result_with_relax_cutlass_offload(mod, qkv, bias) + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + if __name__ == "__main__": tvm.testing.main() From b48e7bf65f137567a98e33b32d0759401de47e59 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Wed, 12 Apr 2023 12:05:22 -0700 Subject: [PATCH 2/3] fix lint --- python/tvm/contrib/cutlass/attention_operation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index 8e34385b4441..7c5b7048a248 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -86,13 +86,13 @@ def instantiate_attention_template(attrs): p.v_strideH = p.head_dim_value; // H' // stride for S - p.q_strideM = p.k_strideM = p.v_strideM = - p.q_strideH * p.num_heads + - p.k_strideH * p.num_heads + + p.q_strideM = p.k_strideM = p.v_strideM = + p.q_strideH * p.num_heads + + p.k_strideH * p.num_heads + p.v_strideH * p.num_heads; // H * N + H * N + H * N' // stride for B - p.q_strideB = p.k_strideB = p.v_strideB = + p.q_strideB = p.k_strideB = p.v_strideB = p.q_strideM * p.num_queries; // (H * N + H * N + H * N') * S """, } From 4807dd6fa09d5e4dc389f850d3a74ae73af483a3 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Wed, 12 Apr 2023 13:27:26 -0700 Subject: [PATCH 3/3] fix lint --- python/tvm/relax/backend/patterns.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index a3f4ea8fa2bf..9e34b0c96472 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -19,7 +19,7 @@ from typing import Dict, Mapping, Tuple, Union -from tvm.relax.dpl.pattern import DFPattern, is_const, is_op, is_tuple, is_tuple_get_item, wildcard +from tvm.relax.dpl.pattern import DFPattern, is_op, is_tuple_get_item, wildcard def _with_bias_activation_pattern( @@ -168,6 +168,11 @@ def 
make_attention_pattern(with_bias: bool = False): """ Create pattern for fused multi head attention. + Parameters + ---------- + with_bias: bool + Whether or not to include bias addition + Returns ------- pattern: DFPattern @@ -193,6 +198,24 @@ def make_attention_pattern(with_bias: bool = False): def make_stacked_attention_pattern(with_bias: bool = False): + """ + Create pattern for fused multi head attention with stacked input. + + Parameters + ---------- + with_bias: bool + Whether or not to include bias addition + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a fused multi head attention. + + annotations: Mapping[str, DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. + """ stacked_qkv = wildcard() qkv_tuple = is_op("relax.split")(stacked_qkv) query_reshape_list = wildcard()
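
Illustrative note (not part of the patch series): the "qkv_stacked" layout added above relies on the fact that splitting a stacked [B, S, N*H + N*H + N*H'] tensor and reshaping the pieces is equivalent to reading Q, K and V directly out of the stacked buffer at element offsets 0, N*H and 2*N*H with a shared row stride of N*H + N*H + N*H'. The standalone NumPy sketch below checks that equivalence with arbitrary example sizes; it mirrors the pointer-offset and stride assignments in the CUTLASS template but is not code from this PR.

import numpy as np

# Arbitrary example sizes: batch, sequence length, heads, head dims for Q/K and for V.
b, s, n, h, h_v = 2, 4, 3, 8, 16
qkv = np.random.randn(b, s, n * h + n * h + n * h_v).astype("float32")

# Un-fused path: split -> reshape, as the original graph does before pattern fusion.
q_ref, k_ref, v_ref = np.split(qkv, [n * h, 2 * n * h], axis=2)
q_ref = q_ref.reshape(b, s, n, h)
k_ref = k_ref.reshape(b, s, n, h)
v_ref = v_ref.reshape(b, s, n, h_v)

# Fused path: strided views into the stacked buffer. All three tensors share the same
# stride along S (the full stacked row width), K starts at element offset n*h and V at
# 2*n*h -- the same offsets and strides the "qkv_stacked" template assigns.
itemsize = qkv.itemsize
row = n * h + n * h + n * h_v  # q_strideM == k_strideM == v_strideM


def strided_view(offset, head_dim):
    """View of (b, s, n, head_dim) elements starting at `offset` inside each stacked row."""
    return np.lib.stride_tricks.as_strided(
        qkv[:, :, offset:],
        shape=(b, s, n, head_dim),
        strides=(s * row * itemsize, row * itemsize, head_dim * itemsize, itemsize),
    )


np.testing.assert_array_equal(strided_view(0, h), q_ref)  # query
np.testing.assert_array_equal(strided_view(n * h, h), k_ref)  # key
np.testing.assert_array_equal(strided_view(2 * n * h, h_v), v_ref)  # value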