diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index f7dee4e3b80a..7c5b7048a248 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -19,45 +19,87 @@
 from .library import *


-def instantiate_attention_template(attrs, func_args):
+def instantiate_attention_template(attrs):
     """Return CUTLASS host code for fused multi head attention
     based on a template and the provided attribute map."""

     bias_template = {
         "B11S'": """
-  CHECK(${arg3}->ndim == 2); // B, 1, 1, S'
+  CHECK(${bias}->ndim == 2); // B, 1, 1, S'

-  p.attn_bias_ptr = reinterpret_cast<T*>(${arg3}->data);
+  p.attn_bias_ptr = reinterpret_cast<T*>(${bias}->data);
   p.bias_strideM = 0; // 0
   p.bias_strideH = 0; // 0
   p.bias_strideB = p.num_keys; // S'
 """,
         "B1SS'": """
-  CHECK(${arg3}->ndim == 3); // B, 1, S, S'
+  CHECK(${bias}->ndim == 3); // B, 1, S, S'

-  p.attn_bias_ptr = reinterpret_cast<T*>(${arg3}->data);
+  p.attn_bias_ptr = reinterpret_cast<T*>(${bias}->data);
   p.bias_strideM = p.num_keys; // S'
   p.bias_strideH = 0; // 0
   p.bias_strideB = p.bias_strideM * p.num_queries; // S' * S
 """,
         "BNSS'": """
-  CHECK(${arg3}->ndim == 4); // B, N, S, S'
+  CHECK(${bias}->ndim == 4); // B, N, S, S'

-  p.attn_bias_ptr = reinterpret_cast<T*>(${arg3}->data);
+  p.attn_bias_ptr = reinterpret_cast<T*>(${bias}->data);
   p.bias_strideM = p.num_keys; // S'
   p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S
   p.bias_strideB = p.bias_strideH * p.num_heads; // S' * S * N
 """,
     }

+    qkv_template = {
+        "default": """
+  p.query_ptr = reinterpret_cast<T*>(${query}->data);
+  p.key_ptr = reinterpret_cast<T*>(${key}->data);
+  p.value_ptr = reinterpret_cast<T*>(${value}->data);
+  CHECK(${query}->ndim == 4); // B, S, N, H
+  CHECK(${key}->ndim == 4); // B, S', N, H
+  CHECK(${value}->ndim == 4); // B, S', N, H'
+
+  // stride for N
+  p.q_strideH = p.head_dim; // H
+  p.k_strideH = p.head_dim; // H
+  p.v_strideH = p.head_dim_value; // H'
+
+  // stride for S
+  p.q_strideM = p.q_strideH * p.num_heads; // H * N
+  p.k_strideM = p.k_strideH * p.num_heads; // H * N
+  p.v_strideM = p.v_strideH * p.num_heads; // H' * N
+
+  // stride for B
+  p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
+  p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
+  p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'
+""",
+        "qkv_stacked": """
+  p.query_ptr = reinterpret_cast<T*>(${qkv}->data);
+  p.key_ptr = reinterpret_cast<T*>(${qkv}->data) + p.head_dim * p.num_heads;
+  p.value_ptr = reinterpret_cast<T*>(${qkv}->data) + p.head_dim * p.num_heads * 2;
+  CHECK(${qkv}->ndim == 3); // B, S, NH + NH + NH'
+
+  // stride for N
+  p.q_strideH = p.head_dim; // H
+  p.k_strideH = p.head_dim; // H
+  p.v_strideH = p.head_dim_value; // H'
+
+  // stride for S
+  p.q_strideM = p.k_strideM = p.v_strideM =
+      p.q_strideH * p.num_heads +
+      p.k_strideH * p.num_heads +
+      p.v_strideH * p.num_heads; // H * N + H * N + H' * N
+
+  // stride for B
+  p.q_strideB = p.k_strideB = p.v_strideB =
+      p.q_strideM * p.num_queries; // (H * N + H * N + H' * N) * S
+""",
+    }
+
     template = """
   using T = ${data_type};

-  CHECK(${arg0}->ndim == 4); // B, S, N, H
-  CHECK(${arg1}->ndim == 4); // B, S', N, H
-  CHECK(${arg2}->ndim == 4); // B, S', N, H'
-  CHECK(out0->ndim == 4); // B, S, N, H'
-
   using Attention = AttentionKernel;

   typename Attention::Params p;
-
-  p.query_ptr = reinterpret_cast<T*>(${arg0}->data);
-  p.key_ptr = reinterpret_cast<T*>(${arg1}->data);
-  p.value_ptr = reinterpret_cast<T*>(${arg2}->data);
   p.logsumexp_ptr = nullptr;
   p.output_ptr = reinterpret_cast<T*>(out0->data);
   p.output_accum_ptr = nullptr;
@@ -92,22 +130,11 @@ def instantiate_attention_template(attrs, func_args):
   p.num_keys = ${num_keys}; // S'
   p.scale = ${scale};

-  // stride for N
-  p.q_strideH = p.head_dim; // H
-  p.k_strideH = p.head_dim; // H
-  p.v_strideH = p.head_dim_value; // H'
-  // stride for S
-  p.q_strideM = p.q_strideH * p.num_heads; // H * N
-  p.k_strideM = p.k_strideH * p.num_heads; // H * N
-  p.v_strideM = p.v_strideH * p.num_heads; // H' * N
   p.o_strideM = p.head_dim_value * p.num_heads; // H' * N
+  CHECK(out0->ndim == 4); // B, S, N, H'

-  // stride for B
-  p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
-  p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
-  p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'
-
+  ${qkv_template}
   ${bias_template}

   constexpr auto kernel_fn = attention_kernel_batched_impl;
@@ -126,9 +153,10 @@ def instantiate_attention_template(attrs, func_args):
     template = substitute_template(
         template,
-        {"bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" in attrs else ""},
+        {
+            "qkv_template": qkv_template[attrs["qkv_layout"]],
+            "bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" in attrs else "",
+        },
     )

-    for i, arg in enumerate(func_args):
-        attrs["arg{}".format(i)] = arg
     return substitute_template(template, attrs)
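
Not part of the patch: a NumPy sanity sketch of the "qkv_stacked" stride arithmetic above. It checks that viewing a packed (B, S, N*H + N*H + N*H') buffer with `q_strideH = H`, `q_strideM = H*N + H*N + H'*N`, and `q_strideB = q_strideM * S` recovers the same query elements as the split-then-reshape reference. All shapes are made-up illustration values.

```python
import numpy as np

b, s, n, h, h_v = 2, 3, 4, 8, 16
qkv = np.random.randn(b, s, n * h + n * h + n * h_v).astype("float32")

# Reference split, matching what the Relax graph does before fusion.
q_ref, _, _ = np.split(qkv, [n * h, n * h * 2], axis=2)
q_ref = q_ref.reshape(b, s, n, h)

# Stride-based view, matching the C++ template (k/v are the same plus an offset).
flat = qkv.reshape(-1)
row = n * h + n * h + n * h_v                 # elements per (batch, seq) position
q_strideB, q_strideM, q_strideH = row * s, row, h
q_view = np.lib.stride_tricks.as_strided(
    flat,
    shape=(b, s, n, h),
    strides=(q_strideB * 4, q_strideM * 4, q_strideH * 4, 4),  # float32 itemsize = 4
)
assert np.array_equal(q_view, q_ref)
```
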
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 43494991a04c..e943aec6b1e1 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -546,12 +546,17 @@ def _extract_relax_function_signature(f):

     for i, arg in enumerate(f.params):
         sinfo = arg.struct_info
-        signature["arg%d_shape" % i] = get_const_tuple(sinfo.shape)
-        signature["arg%d_dtype" % i] = sinfo.dtype
+        if isinstance(sinfo, relax.TensorStructInfo):
+            signature["arg%d_shape" % i] = get_const_tuple(sinfo.shape)
+            signature["arg%d_dtype" % i] = sinfo.dtype
+        elif isinstance(sinfo, relax.ShapeStructInfo):
+            signature["arg%d_shape" % i] = get_const_tuple(sinfo.values)
+        else:
+            raise NotImplementedError()

     ret_sinfo = f.ret_struct_info
     if ret_sinfo.shape is not None:
-        signature["ret_shape"] = list(ret_sinfo.shape)
+        signature["ret_shape"] = get_const_tuple(ret_sinfo.shape)
     else:
         signature["ret_shape"] = None
     signature["ret_dtype"] = ret_sinfo.dtype
@@ -779,34 +784,42 @@ def handle_attention(self, f, op_type):
            op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs
        else:
            raise ValueError(f"Cannot find call node for attention")
-        q_shape = signature["arg0_shape"]
-        k_shape = signature["arg1_shape"]
-        v_shape = signature["arg2_shape"]
+        arg = {}
+
+        if "stacked_attention" in op_type:
+            arg["arg0_shape"] = signature["arg0_shape"]
+            arg["arg0_dtype"] = signature["arg0_dtype"]
+            arg["arg1_shape"] = q_shape = signature["arg1_shape"]
+            arg["arg2_shape"] = k_shape = signature["arg2_shape"]
+            arg["arg3_shape"] = v_shape = signature["arg3_shape"]
+            if "arg4_dtype" in signature:
+                arg["bias_dtype"] = signature["arg4_dtype"]
+            if "arg4_shape" in signature:
+                arg["bias_shape"] = signature["arg4_shape"]
+            qkv_layout = "qkv_stacked"
+        else:
+            arg["arg0_shape"] = q_shape = signature["arg0_shape"]
+            arg["arg1_shape"] = k_shape = signature["arg1_shape"]
+            arg["arg2_shape"] = v_shape = signature["arg2_shape"]
+            arg["arg0_dtype"] = signature["arg0_dtype"]
+            arg["arg1_dtype"] = signature["arg1_dtype"]
+            arg["arg2_dtype"] = signature["arg2_dtype"]
+            if "arg3_dtype" in signature:
+                arg["bias_dtype"] = signature["arg3_dtype"]
+            if "arg3_shape" in signature:
+                arg["bias_shape"] = signature["arg3_shape"]
+            qkv_layout = "default"
         out_shape = signature["ret_shape"]
-        q_dtype = signature["arg0_dtype"]
-        k_dtype = signature["arg1_dtype"]
-        v_dtype = signature["arg2_dtype"]
         out_dtype = signature["ret_dtype"]
         num_batches, num_queries, num_heads, head_dim = q_shape
         _, num_keys, _, _ = k_shape
         _, _, _, head_dim_value = v_shape
         scale = op_attrs.scale
-        bias = {}
-        if "arg3_dtype" in signature:
-            bias["arg3_dtype"] = signature["arg3_dtype"]
-        if "arg3_shape" in signature:
-            bias["arg3_shape"] = signature["arg3_shape"]
         return f.with_attrs(
             {
                 "op_type": op_type,
-                "arg0_dtype": q_dtype,
-                "arg1_dtype": k_dtype,
-                "arg2_dtype": v_dtype,
                 "ret_dtype": out_dtype,
-                "arg0_shape": q_shape,
-                "arg1_shape": k_shape,
-                "arg2_shape": v_shape,
                 "ret_shape": out_shape,
                 "num_batches": num_batches,
                 "num_queries": num_queries,
@@ -816,7 +829,8 @@ def handle_attention(self, f, op_type):
                 "head_dim_value": head_dim_value,
                 "scale": scale,
                 "arch": self.options["sm"],
-                **bias,
+                "qkv_layout": qkv_layout,
+                **arg,
             }
         )
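
Not part of the patch: for orientation, a sketch of how the composite function's arguments line up in the stacked case and why `handle_attention` reads q/k/v shapes from `arg1`..`arg3`. All shapes below are made-up illustration values.

```python
# Illustrative only: argument layout of a "cutlass.stacked_attention" composite.
# arg0 is the packed QKV tensor; arg1..arg3 are the R.shape arguments of the
# three reshapes (hence the new relax.ShapeStructInfo branch above); a bias,
# if present, comes last.
signature_sketch = {
    "arg0_shape": (4, 8, 32 * 64 + 32 * 64 + 32 * 32),  # stacked QKV: B, S, NH + NH + NH'
    "arg0_dtype": "float16",
    "arg1_shape": (4, 8, 32, 64),  # query reshape target -> used as q_shape
    "arg2_shape": (4, 8, 32, 64),  # key reshape target   -> used as k_shape
    "arg3_shape": (4, 8, 32, 32),  # value reshape target -> used as v_shape
    "arg4_shape": (4, 32, 8, 8),   # optional bias
    "arg4_dtype": "float16",
}

# handle_attention then re-keys the bias entries as bias_shape/bias_dtype, sets
# qkv_layout = "qkv_stacked", and derives num_heads, head_dim, etc. from q/k/v.
```
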
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 61c88c657f05..bb4d2243297c 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -500,12 +500,6 @@ def instantiate_template(func_name, annotations, func_args):
        if k in annotations:
            attrs[k] = annotations[k]

-    arg0_shape = annotations["arg0_shape"]
-    arg1_shape = annotations["arg1_shape"]
-    attrs["ElementInputA"] = DataTypeTag[dtype_map[annotations["arg0_dtype"]]]
-    attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations["arg1_dtype"]]]
-    attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]]
-
     headers = []

     if "relu" in func_name:
@@ -649,12 +643,12 @@ def get_batch_on_arg(arg_name, arg_shape):
        if "conv2d_transpose" in func_name:
            headers.append("cutlass/conv/kernel/default_conv2d_dgrad.h")
            activation_shape = output_shape
-            output_shape = arg0_shape
+            output_shape = annotations["arg0_shape"]
        elif "backward" in func_name:
            headers.append("cutlass/conv/kernel/default_conv2d_wgrad.h")
-            activation_shape = arg1_shape
+            activation_shape = annotations["arg1_shape"]
            weight_shape = output_shape
-            output_shape = arg0_shape
+            output_shape = annotations["arg0_shape"]
        elif "residual" in func_name:
            headers.append("cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h")
        else:
@@ -731,13 +725,26 @@ def get_batch_on_arg(arg_name, arg_shape):
        ), "Cutlass may generate nan occasionally when scale == 0.0"
        attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
        attrs["kSupportsDropout"] = False
-        if len(func_args) > 3:
+        attrs["qkv_layout"] = annotations["qkv_layout"]
+        if attrs["qkv_layout"] == "default":
+            attrs["query"] = func_args[0]
+            attrs["key"] = func_args[1]
+            attrs["value"] = func_args[2]
+            if len(func_args) > 3:
+                attrs["bias"] = func_args[3]
+        elif attrs["qkv_layout"] == "qkv_stacked":
+            attrs["qkv"] = func_args[0]
+            if len(func_args) > 4:
+                attrs["bias"] = func_args[4]
+        else:
+            raise NotImplementedError()
+        if "bias" in attrs:
            attrs["kSupportsBias"] = True
-            if len(annotations["arg3_shape"]) == 4:
+            if len(annotations["bias_shape"]) == 4:
                attrs["bias_layout"] = "BNSS'"
-            elif len(annotations["arg3_shape"]) == 3:
+            elif len(annotations["bias_shape"]) == 3:
                attrs["bias_layout"] = "B1SS'"
-            elif len(annotations["arg3_shape"]) == 2:
+            elif len(annotations["bias_shape"]) == 2:
                attrs["bias_layout"] = "B11S'"
            else:
                raise NotImplementedError()
@@ -745,7 +752,7 @@ def get_batch_on_arg(arg_name, arg_shape):
            # To support negative scale in current Cutlass implementation,
            # kSupportsBias should be set true, or there are nan's as result.
            attrs["kSupportsBias"] = attrs["scale"] < 0
-        code = instantiate_attention_template(attrs, func_args)
+        code = instantiate_attention_template(attrs)
        return CodegenResult(code, headers)

    raise ValueError("Do not have a template for {}".format(func_name))
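
Not part of the patch: a condensed paraphrase of the new dispatch above, showing how runtime buffer names are bound to template placeholders for the two QKV layouts. The helper name is hypothetical, but the index logic mirrors the code in `instantiate_template`.

```python
def bind_qkv_args(qkv_layout, func_args):
    """Map positional codegen arguments to the names used by the C++ template."""
    attrs = {}
    if qkv_layout == "default":
        attrs["query"], attrs["key"], attrs["value"] = func_args[:3]
        if len(func_args) > 3:
            attrs["bias"] = func_args[3]
    elif qkv_layout == "qkv_stacked":
        attrs["qkv"] = func_args[0]
        # func_args[1:4] are the shape expressions consumed by the reshapes,
        # so a bias, if present, is the fifth argument.
        if len(func_args) > 4:
            attrs["bias"] = func_args[4]
    else:
        raise NotImplementedError(qkv_layout)
    return attrs

# e.g. bind_qkv_args("qkv_stacked", ["arg0", "arg1", "arg2", "arg3", "arg4"])
# -> {"qkv": "arg0", "bias": "arg4"}
```
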
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 856cd4d7871f..4515118f5889 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -29,6 +29,7 @@
     make_fused_bias_activation_pattern,
     make_matmul_pattern,
     make_residual_block_pattern,
+    make_stacked_attention_pattern,
 )


@@ -244,6 +245,14 @@ def attention_patterns():
            "cutlass.attention_bias",
            *make_attention_pattern(with_bias=True),
        ),
+        (
+            "cutlass.stacked_attention",
+            *make_stacked_attention_pattern(),
+        ),
+        (
+            "cutlass.stacked_attention",
+            *make_stacked_attention_pattern(with_bias=True),
+        ),
     ]
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index e27b91b3eaa6..9e34b0c96472 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -19,7 +19,7 @@

 from typing import Dict, Mapping, Tuple, Union

-from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard
+from tvm.relax.dpl.pattern import DFPattern, is_op, is_tuple_get_item, wildcard


 def _with_bias_activation_pattern(
@@ -168,6 +168,11 @@ def make_attention_pattern(with_bias: bool = False):
     """
     Create pattern for fused multi head attention.

+    Parameters
+    ----------
+    with_bias: bool
+        Whether or not to include bias addition
+
     Returns
     -------
     pattern: DFPattern
@@ -190,3 +195,45 @@ def make_attention_pattern(with_bias: bool = False):
        out = is_op("relax.nn.attention")(query, key, value)

     return out, annotations
+
+
+def make_stacked_attention_pattern(with_bias: bool = False):
+    """
+    Create pattern for fused multi head attention with stacked input.
+
+    Parameters
+    ----------
+    with_bias: bool
+        Whether or not to include bias addition
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing a fused multi head attention.
+
+    annotations: Mapping[str, DFPattern]
+        A mapping from name to sub pattern. It can be used to extract
+        important expressions from match result, to power the partition
+        check function and codegen.
+    """
+    stacked_qkv = wildcard()
+    qkv_tuple = is_op("relax.split")(stacked_qkv)
+    query_reshape_list = wildcard()
+    key_reshape_list = wildcard()
+    value_reshape_list = wildcard()
+    query = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 0), query_reshape_list)
+    key = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 1), key_reshape_list)
+    value = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 2), value_reshape_list)
+    annotations = {
+        "stacked_qkv": stacked_qkv,
+        "query_reshape_list": query_reshape_list,
+        "key_reshape_list": key_reshape_list,
+        "value_reshape_list": value_reshape_list,
+    }
+    if with_bias:
+        bias = wildcard()
+        annotations["bias"] = bias
+        out = is_op("relax.nn.attention_bias")(query, key, value, bias)
+    else:
+        out = is_op("relax.nn.attention")(query, key, value)
+    return out, annotations
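
Not part of the patch: a sketch of the subgraph `make_stacked_attention_pattern()` is intended to match, and a minimal usage check. The import path assumes this patch is applied; shapes in the comment are illustrative.

```python
# The pattern targets a Relax subgraph of this shape (pseudo-Relax):
#
#   qkv_tuple = R.split(stacked_qkv, [n*h, 2*n*h], axis=2)
#   q = R.reshape(qkv_tuple[0], (b, s, n, h))
#   k = R.reshape(qkv_tuple[1], (b, s, n, h))
#   v = R.reshape(qkv_tuple[2], (b, s, n, h_v))
#   out = R.nn.attention(q, k, v)        # or R.nn.attention_bias(q, k, v, bias)
#
from tvm.relax.backend.patterns import make_stacked_attention_pattern

pattern, annotations = make_stacked_attention_pattern(with_bias=True)
# The annotation map exposes the stacked input, the reshape shape arguments,
# and (here) the bias, for use by partition checks and codegen.
assert set(annotations) >= {"stacked_qkv", "query_reshape_list", "bias"}
```
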
+ """ + stacked_qkv = wildcard() + qkv_tuple = is_op("relax.split")(stacked_qkv) + query_reshape_list = wildcard() + key_reshape_list = wildcard() + value_reshape_list = wildcard() + query = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 0), query_reshape_list) + key = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 1), key_reshape_list) + value = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 2), value_reshape_list) + annotations = { + "stacked_qkv": stacked_qkv, + "query_reshape_list": query_reshape_list, + "key_reshape_list": key_reshape_list, + "value_reshape_list": value_reshape_list, + } + if with_bias: + bias = wildcard() + annotations["bias"] = bias + out = is_op("relax.nn.attention_bias")(query, key, value, bias) + else: + out = is_op("relax.nn.attention")(query, key, value) + return out, annotations diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 8ef68baf6832..730d098510c0 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -59,6 +59,8 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, auto sinfo = GetStructInfo(arg); if (const auto* tensor_sinfo = sinfo.as()) { arg_types.emplace_back(backend::DType2String(tensor_sinfo->dtype)); + } else if (const auto* shape_sinfo = sinfo.as()) { + arg_types.emplace_back(backend::DType2String(shape_sinfo->values.value()[0]->dtype)); } else { LOG(FATAL) << "Unimplemented"; } diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index b9ba4f4dc9af..9288db3eb545 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -592,10 +592,10 @@ def attention_bias_size(request): return request.param -def test_attention_bias_offload(attention_bias_size, attention_dtype): +def test_attention_bias_offload(attention_bias_size): b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_bias_size q, k, v, bias, ref = get_numpy_attention_ref( - b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, "none", attention_dtype + b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, "none", "float32" ) mod = get_relax_attention_module(q, k, v, bias) @@ -620,10 +620,10 @@ def attention_scale(request): return request.param -def test_attention_scale_offload(attention_scale_size, attention_scale, attention_dtype): +def test_attention_scale_offload(attention_scale_size, attention_scale): b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_scale_size q, k, v, bias, ref = get_numpy_attention_ref( - b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, attention_scale, attention_dtype + b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, attention_scale, "float32" ) mod = get_relax_attention_module(q, k, v, bias, attention_scale) @@ -634,5 +634,85 @@ def test_attention_scale_offload(attention_scale_size, attention_scale, attentio tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) +@memoize("topi.tests.test_codegen_cutlass.test_stacked_attention_offload") +def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, bias_reshape, qk_scale, dtype): + qkv = np.random.randn(b, s, n * h + n * h + n * h_v).astype(dtype) + split_qkv = np.split(qkv, [n * h, n * h * 2], axis=2) + q = np.reshape(split_qkv[0], (b, s, n, h)) + k = np.reshape(split_qkv[1], (b, s, n, h)) + v = np.reshape(split_qkv[2], (b, s, n, h_v)) + qt = q.transpose(0, 2, 1, 3) # b, n, s, h + kt = k.transpose(0, 2, 3, 1) # b, n, h, s + if not qk_scale == "none": + 
+        score = qt @ kt * qk_scale  # b, n, s, s
+    else:
+        score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s
+    if not bias_shape == "none":
+        bias = np.random.randn(*bias_shape).astype(dtype)
+        score = score + bias.reshape(*bias_reshape)  # b, n, s, s
+    else:
+        bias = None
+    attn = tvm.topi.testing.softmax_python(score, -1)
+    vt = v.transpose(0, 2, 1, 3)  # b, n, s, h_v
+    ref = attn @ vt  # b, n, s, h_v
+    return qkv, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+
+
+def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale=None):
+    dtype = str(qkv.dtype)
+
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import relax as relax_builder, tir as T
+
+    if qk_scale is not None:
+        qk_scale = T.FloatImm("float32", qk_scale)
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype))
+            if bias is not None:
+                bias = R.arg("bias", R.Tensor(bias.shape, dtype))
+            with R.dataflow() as frame:
+                qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
+                q = R.reshape(qkv_tuple[0], [b, s, n, h])
+                k = R.reshape(qkv_tuple[1], [b, s, n, h])
+                v = R.reshape(qkv_tuple[2], [b, s, n, h_v])
+                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
+                R.output(result)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
+@pytest.fixture(
+    params=[
+        # B, S, N, (H, H'), bias_shape, bias_reshape, scale
+        (4, 8, 32, (64, 32), "none", "none", "none"),
+        (4, 8, 32, (64, 32), (4, 32, 8, 8), (4, 32, 8, 8), 0.5),
+    ]
+)
+def stacked_attention_size(request):
+    return request.param
+
+
+def test_stacked_attention_offload(stacked_attention_size):
+    b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
+    qkv, bias, ref = get_numpy_stacked_attention_ref(
+        b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
+    )
+    if scale == "none":
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias)
+    else:
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias, scale)
+    if bias is None:
+        out = get_result_with_relax_cutlass_offload(mod, qkv)
+    else:
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
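
Not part of the patch: a standalone NumPy cross-check of the reference math used by the new test. It splits a packed QKV buffer, runs plain scaled-dot-product attention on the pieces, and produces output in the same (b, s, n, h_v) layout the test compares against. All shapes are made-up illustration values.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

b, s, n, h, h_v = 2, 4, 3, 8, 16
qkv = np.random.randn(b, s, n * h + n * h + n * h_v).astype("float32")
q, k, v = np.split(qkv, [n * h, n * h * 2], axis=2)
q = q.reshape(b, s, n, h).transpose(0, 2, 1, 3)    # b, n, s, h
k = k.reshape(b, s, n, h).transpose(0, 2, 1, 3)    # b, n, s, h
v = v.reshape(b, s, n, h_v).transpose(0, 2, 1, 3)  # b, n, s, h_v

score = q @ k.transpose(0, 1, 3, 2) / np.sqrt(h)   # default 1/sqrt(H) scaling
out = softmax(score) @ v                           # b, n, s, h_v
out = out.transpose(0, 2, 1, 3)                    # b, s, n, h_v, as the test expects
print(out.shape)
```
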