diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index d8359c804b7e..92ebbf1dc461 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit d8359c804b7e3915a0f0668c19213f63ae88aac6
+Subproject commit 92ebbf1dc4612bf838ace6f2e6d262919f0abd63
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
new file mode 100644
index 000000000000..9093a03dd6ed
--- /dev/null
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):
+    """Return CUTLASS host code for fused multi head attention
+    based on a template and the provided attribute map."""
+
+    bias_template = {
+        "B11S'": """
+  CHECK(${arg3}->ndim == 2); // B, 1, 1, S'
+
+  p.attn_bias_ptr = reinterpret_cast<T*>(${arg3}->data);
+  p.bias_strideM = 0; // 0
+  p.bias_strideH = 0; // 0
+  p.bias_strideB = p.num_keys; // S'
+""",
+        "B1SS'": """
+  CHECK(${arg3}->ndim == 3); // B, 1, S, S'
+
+  p.attn_bias_ptr = reinterpret_cast<T*>(${arg3}->data);
+  p.bias_strideM = p.num_keys; // S'
+  p.bias_strideH = 0; // 0
+  p.bias_strideB = p.bias_strideM * p.num_queries; // S' * S
+""",
+        "BNSS'": """
+  CHECK(${arg3}->ndim == 4); // B, N, S, S'
+
+  p.attn_bias_ptr = reinterpret_cast<T*>(${arg3}->data);
+  p.bias_strideM = p.num_keys; // S'
+  p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S
+  p.bias_strideB = p.bias_strideH * p.num_heads; // S' * S * N
+""",
+    }
+
+    template = """
+  using T = ${data_type};
+
+  CHECK(${arg0}->ndim == 4); // B, S, N, H
+  CHECK(${arg1}->ndim == 4); // B, S', N, H
+  CHECK(${arg2}->ndim == 4); // B, S', N, H'
+  CHECK(out0->ndim == 4); // B, S, N, H'
+
+  using Attention =
+      AttentionKernel<T, ${arch}, ${kIsAligned}, ${kQueriesPerBlock}, ${kKeysPerBlock},
+                      ${kSingleValueIteration}, ${kSupportsDropout}, ${kSupportsBias}>;
+
+  typename Attention::Params p;
+
+  p.query_ptr = reinterpret_cast<T*>(${arg0}->data);
+  p.key_ptr = reinterpret_cast<T*>(${arg1}->data);
+  p.value_ptr = reinterpret_cast<T*>(${arg2}->data);
+  p.logsumexp_ptr = nullptr;
+  p.output_ptr = reinterpret_cast<T*>(out0->data);
+  p.output_accum_ptr = nullptr;
+  if (Attention::kNeedsOutputAccumulatorBuffer) {
+    cudaMalloc(
+      &p.output_accum_ptr,
+      ${output_size} * sizeof(Attention::output_accum_t)
+    );
+  }
+
+  p.num_heads = ${num_heads}; // N
+  p.num_batches = ${num_batches}; // B
+  p.head_dim = ${head_dim}; // H
+  p.head_dim_value = ${head_dim_value}; // H'
+  p.num_queries = ${num_queries}; // S
+  p.num_keys = ${num_keys}; // S'
+  p.scale = 1.0f / sqrt(float(${head_dim}));
+
+  // stride for N
+  p.q_strideH = p.head_dim; // H
+  p.k_strideH = p.head_dim; // H
+  p.v_strideH = p.head_dim_value; // H'
+
+  // stride for S
+  p.q_strideM = p.q_strideH * p.num_heads; // H * N
+  p.k_strideM = p.k_strideH * p.num_heads; // H * N
+  p.v_strideM = p.v_strideH * p.num_heads; // H' * N
+  p.o_strideM = p.head_dim_value * p.num_heads; // H' * N
+
+  // stride for B
+  p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
+  p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
+  p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'
+
+  ${bias_template}
+
+  constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
+  int smem_bytes = sizeof(typename Attention::SharedStorage);
+  if (smem_bytes > 0xc000) {
+    static bool once = [&]() {
+      cudaFuncSetAttribute(
+          kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+      return true;
+    }();
+  }
+
+  CHECK(Attention::check_supported(p));
+  kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
+"""
+    if attrs["kSupportsBias"]:
+        template = substitute_template(
+            template, {"bias_template": bias_template[attrs["bias_layout"]]}
+        )
+    else:
+        template = substitute_template(template, {"bias_template": ""})
+
+    for i, arg in enumerate(func_args):
+        attrs["arg{}".format(i)] = arg
+
+    return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 7e81113f4431..0e8d419baecf 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -57,6 +57,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
     cutlass_root = _get_cutlass_path()
     cutlass_include = os.path.join(cutlass_root, "include")
     cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")
+    cutlass_attention_include = os.path.join(cutlass_root, "examples/41_fused_multi_head_attention")
 
     kwargs = {}
     kwargs["cc"] = "nvcc"
@@ -71,6 +72,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
         "-std=c++17",
         "-I" + cutlass_include,
         "-I" + cutlass_util_include,
+        "-I" + cutlass_attention_include,
     ]
     if use_fast_math:
         kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
@@ -756,6 +758,49 @@ def handle_matmul(self, f, op_type):
             }
         )
 
+    def handle_attention(self, f, op_type):
+        """Annotate a fused multi head attention op."""
+        signature = _extract_relax_function_signature(f)
+
+        q_shape = signature["arg0_shape"]
+        k_shape = signature["arg1_shape"]
+        v_shape = signature["arg2_shape"]
+        out_shape = signature["ret_shape"]
+        q_dtype = signature["arg0_dtype"]
+        k_dtype = signature["arg1_dtype"]
+        v_dtype = signature["arg2_dtype"]
+        out_dtype = signature["ret_dtype"]
+        num_batches, num_queries, num_heads, head_dim = q_shape
+        _, num_keys, _, _ = k_shape
+        _, _, _, head_dim_value = v_shape
+        bias = {}
+        if "arg3_dtype" in signature:
+            bias["arg3_dtype"] = signature["arg3_dtype"]
+        if "arg3_shape" in signature:
+            bias["arg3_shape"] = signature["arg3_shape"]
+
+        return f.with_attrs(
+            {
+                "op_type": op_type,
+                "arg0_dtype": q_dtype,
+                "arg1_dtype": k_dtype,
+                "arg2_dtype": v_dtype,
+                "ret_dtype": out_dtype,
+                "arg0_shape": q_shape,
+                "arg1_shape": k_shape,
+                "arg2_shape": v_shape,
+                "ret_shape": out_shape,
+                "num_batches": num_batches,
+                "num_queries": num_queries,
+                "num_keys": num_keys,
+                "num_heads": num_heads,
+                "head_dim": head_dim,
+                "head_dim_value": head_dim_value,
+                "arch": self.options["sm"],
+                **bias,
+            }
+        )
+
     def visit_function_(self, f):
         if "Composite" not in f.attrs:
             body = super().visit_expr(f.body)
@@ -767,6 +812,8 @@ def visit_function_(self, f):
             return self.handle_conv2d(f, op_type)
         elif "matmul" in op_type:
            return self.handle_matmul(f, op_type)
+        elif "attention" in op_type:
+            return self.handle_attention(f, op_type)
 
         raise ValueError("Unsupported composite {}".format(op_type))
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 2976946dd258..78e2b489c6fa 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -30,8 +30,10 @@
 from . import _ffi_api as ffi
 from .conv2d_operation import instantiate_conv2d_template
 from .gemm_operation import instantiate_gemm_template
+from .attention_operation import instantiate_attention_template
 from .library import (
     DataType,
+    DataTypeSize,
     DataTypeTag,
     EpilogueFunctor,
     MathInstruction,
@@ -549,7 +551,7 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_idx):
         attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]]
         attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]]
 
-        attrs["K"] = str(int(lhs_shape[lhs_batched_offset + 1]))
+        attrs["K"] = lhs_shape[lhs_batched_offset + 1]
         attrs["M"] = get_dim(lhs_shape[lhs_batched_offset], lhs_arg, 0, lhs_batched_offset)
 
         if transposed:
@@ -630,27 +632,70 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_idx):
         attrs["N"] = get_dim(activation_shape[0], activation_var, 0)
         attrs["H"] = get_dim(activation_shape[1], activation_var, 1)
         attrs["W"] = get_dim(activation_shape[2], activation_var, 2)
-        attrs["C"] = str(int(activation_shape[3]))
+        attrs["C"] = activation_shape[3]
         attrs["P"] = get_dim(output_shape[1], "out0", 1)
         attrs["Q"] = get_dim(output_shape[2], "out0", 2)
-        attrs["K"] = str(int(output_shape[3]))
-        attrs["R"] = str(int(weight_shape[1]))
-        attrs["S"] = str(int(weight_shape[2]))
-        attrs["pad_h"] = str(int(annotations["padding"][0]))
-        attrs["pad_w"] = str(int(annotations["padding"][1]))
-        attrs["stride_h"] = str(int(annotations["strides"][0]))
-        attrs["stride_w"] = str(int(annotations["strides"][1]))
-        attrs["dilation_h"] = str(int(annotations["dilation"][0]))
-        attrs["dilation_w"] = str(int(annotations["dilation"][1]))
+        attrs["K"] = output_shape[3]
+        attrs["R"] = weight_shape[1]
+        attrs["S"] = weight_shape[2]
+        attrs["pad_h"] = annotations["padding"][0]
+        attrs["pad_w"] = annotations["padding"][1]
+        attrs["stride_h"] = annotations["strides"][0]
+        attrs["stride_w"] = annotations["strides"][1]
+        attrs["dilation_h"] = annotations["dilation"][0]
+        attrs["dilation_w"] = annotations["dilation"][1]
 
         if "splitk" in op_name:
             attrs["split_k_mode"] = "kParallel"
             attrs["split_k_slices"] = str(re.search(r"splitk(\d+)", op_name).group(1))
         else:
             attrs["split_k_mode"] = "kSerial"
-            attrs["split_k_slices"] = "1"
+            attrs["split_k_slices"] = 1
 
         code = instantiate_conv2d_template(attrs, func_args)
         return CodegenResult(code, headers)
 
+    elif "attention" in func_name:
+        headers.append("kernel_forward.h")
+        data_type = dtype_map[annotations["arg0_dtype"]]
+        attrs["data_type"] = DataTypeTag[data_type]
+        attrs["num_batches"] = b = annotations["num_batches"]
+        attrs["num_queries"] = s = annotations["num_queries"]
+        attrs["num_keys"] = annotations["num_keys"]
+        attrs["num_heads"] = n = annotations["num_heads"]
+        attrs["head_dim"] = h = annotations["head_dim"]
+        attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
+        data_type_size = DataTypeSize[data_type]
+        if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
+            attrs["kIsAligned"] = True
+        elif (h % 4 == 0) and (h_v % 4 == 0):
+            attrs["kIsAligned"] = False
+        else:
+            raise NotImplementedError()
+        if h_v > 64:
+            attrs["kQueriesPerBlock"] = 32
+            attrs["kKeysPerBlock"] = 128
attrs["kSingleValueIteration"] = h_v <= 128 + else: + attrs["kQueriesPerBlock"] = 64 + attrs["kKeysPerBlock"] = 64 + attrs["kSingleValueIteration"] = True + attrs["output_size"] = b * s * n * h_v + attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"]) + attrs["kSupportsDropout"] = False + if len(func_args) > 3: + attrs["kSupportsBias"] = True + if len(annotations["arg3_shape"]) == 4: + attrs["bias_layout"] = "BNSS'" + elif len(annotations["arg3_shape"]) == 3: + attrs["bias_layout"] = "B1SS'" + elif len(annotations["arg3_shape"]) == 2: + attrs["bias_layout"] = "B11S'" + else: + raise NotImplementedError() + else: + attrs["kSupportsBias"] = False + code = instantiate_attention_template(attrs, func_args) + return CodegenResult(code, headers) + raise ValueError("Do not have a template for {}".format(func_name)) diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index 8632ab15641d..b72553ef6052 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -20,6 +20,8 @@ import enum from enum import auto as enum_auto +from tvm.tir.expr import IntImm + class GeneratorTarget(enum.Enum): Library = enum_auto() @@ -143,6 +145,10 @@ def substitute_template(template, values): while changed: changed = False for key, value in values.items(): + if isinstance(value, (int, IntImm)): + value = str(int(value)) + elif isinstance(value, bool): + value = str(value).lower() regex = "\\$\\{%s\\}" % key newtext = re.sub(regex, value, text) if newtext != text: diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 2d8908184bd4..19165fa8329f 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -25,7 +25,11 @@ from tvm.relax.dpl import DFPattern from ..pattern_registry import get_patterns_with_prefix, register_patterns -from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern +from ..patterns import ( + make_fused_bias_activation_pattern, + make_matmul_pattern, + make_attention_pattern, +) def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]: @@ -157,6 +161,14 @@ def _check_matmul( ), _check_matmul, ), + ( + "cutlass.attention", + *make_attention_pattern(), + ), + ( + "cutlass.attention_bias", + *make_attention_pattern(with_bias=True), + ), ] ) diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index 2f744af66002..a2ea803d9d92 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -113,3 +113,30 @@ def make_matmul_pattern( out = is_op("relax.matmul")(lhs, rhs) return _with_bias_activation_pattern(out, args, with_bias, activation) + + +def make_attention_pattern(with_bias: bool = False): + """ + Create pattern for fused multi head attention. + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a fused multi head attention. + + args: Mapping[str, DFPattern] + The mapping from arg name to its pattern. It can be used to extract + arg expression from match result. 
+ """ + query = wildcard() + key = wildcard() + value = wildcard() + args = {"query": query, "key": key, "value": value} + if with_bias: + bias = wildcard() + args["bias"] = bias + out = is_op("relax.nn.attention_bias")(query, key, value, bias) + else: + out = is_op("relax.nn.attention")(query, key, value) + + return out, args diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 0ff143fd045b..2fef37249703 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -577,3 +577,42 @@ def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr: The computed result. """ return _ffi_api.cross_entropy_with_logits(predictions, labels) # type: ignore + + +def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None) -> Expr: + r"""Computes fused multi head attention. + + All input tensors are of 4-D tensors with BSNH layout. + + .. math:: + FMA(Q, K, V) = \text{Softmax}(Q @ K^T) @ V + + .. note:: + The input tensor is required to have float16 dtype + + Parameters + ---------- + query: relax.Expr + The input query to the operator. The layout of the input query should be + (batch_size, seq_len, num_head, head_dim). + + key: relax.Expr + The input key to the operator. The layout of the input key should be + (batch_size, seq_len_kv, num_head, head_dim). + + value: relax.Expr + The input value to the operator. The layout of the input value should be + (batch_size, seq_len_kv, num_head, head_dim_v). + + bias: Optional[Expr] + The optional attention bias to the operator. The layout of the attention bias should be + (batch_size, num_head, seq_len, seq_len_kv), + (batch_size, seq_len, seq_len_kv) or (batch_size, seq_len_kv). + + Returns + ------- + result : relax.Expr + The computed result. The layout of the output should be + (batch_size, seq_len, num_head, head_dim_v). + """ + return _ffi_api.attention(query, key, value, bias) # type: ignore diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc new file mode 100644 index 000000000000..e139aa09d692 --- /dev/null +++ b/src/relax/op/nn/attention.cc @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+#include "attention.h"
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/* relax.nn.attention */
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias) {
+  if (bias.defined()) {
+    return Call(Op::Get("relax.nn.attention_bias"),
+                {std::move(query), std::move(key), std::move(value), std::move(bias.value())}, {},
+                {});
+  }
+  return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)},
+              {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention);
+
+StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo q_sinfo = input_sinfo[0];
+  TensorStructInfo k_sinfo = input_sinfo[1];
+  TensorStructInfo v_sinfo = input_sinfo[2];
+  auto diag_dim = [&](TensorStructInfo sinfo, String name) {
+    if (sinfo->ndim != 4) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The " << name << " should have 4 dimensions, namely "
+                       << "[batch size, sequence length, number of heads, dimension of heads].");
+    }
+  };
+  diag_dim(q_sinfo, "query");
+  diag_dim(k_sinfo, "key");
+  diag_dim(v_sinfo, "value");
+  const ShapeExprNode* q_shape = q_sinfo->shape.as<ShapeExprNode>();
+  const ShapeExprNode* k_shape = k_sinfo->shape.as<ShapeExprNode>();
+  const ShapeExprNode* v_shape = v_sinfo->shape.as<ShapeExprNode>();
+  PrimExpr num_batches = q_shape->values[0];
+  PrimExpr num_queries = q_shape->values[1];
+  PrimExpr num_heads = q_shape->values[2];
+  PrimExpr head_dim = q_shape->values[3];
+  PrimExpr num_keys = k_shape->values[1];
+  PrimExpr head_dim_value = v_shape->values[3];
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  auto diag_equal = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) {
+    if (analyzer->CanProve(v1 != v2)) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The " << m1 << " " << dim << " and the " << m2 << " " << dim
+                       << " should be the same. However, the " << dim << " of " << m1 << " is "
However, the " << dim << " of " << m1 << " is " + << v1 << " while the " << dim << " of " << m2 << " is " << v2); + } + }; + diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size"); + diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size"); + diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads"); + diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads"); + diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length"); + diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads"); + + if (input_sinfo.size() == 4) { + TensorStructInfo bias_sinfo = input_sinfo[3]; + const ShapeExprNode* bias_shape = bias_sinfo->shape.as(); + if (bias_sinfo->ndim == 4) { + diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size"); + diag_equal(num_heads, bias_shape->values[1], "query", "bias", "number of heads"); + diag_equal(num_queries, bias_shape->values[2], "query", "bias", "sequence length"); + diag_equal(num_keys, bias_shape->values[3], "key", "bias", "sequence length"); + } else if (bias_sinfo->ndim == 3) { + diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size"); + diag_equal(num_queries, bias_shape->values[1], "query", "bias", "sequence length"); + diag_equal(num_keys, bias_shape->values[2], "key", "bias", "sequence length"); + } else if (bias_sinfo->ndim == 2) { + diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size"); + diag_equal(num_keys, bias_shape->values[1], "key", "bias", "sequence length"); + } else { + ctx->ReportFatal(Diagnostic::Error(call) + << "The bias should have 2, 3 or 4 dimensions." + << "However, the bias input has " << bias_sinfo->ndim << " dimensions."); + } + } + + Array output_shape = {num_batches, num_queries, num_heads, head_dim_value}; + return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.nn.attention") + .set_num_inputs(3) + .add_argument("query", "Tensor", "The input queries tensor.") + .add_argument("key", "Tensor", "The input keys tensor.") + .add_argument("value", "Tensor", "The input values tensor.") + .set_attr("FInferStructInfo", InferStructInfoAttention); + +TVM_REGISTER_OP("relax.nn.attention_bias") + .set_num_inputs(4) + .add_argument("query", "Tensor", "The input queries tensor.") + .add_argument("key", "Tensor", "The input keys tensor.") + .add_argument("value", "Tensor", "The input values tensor.") + .add_argument("bias", "Tensor", "The input bias tensor.") + .set_attr("FInferStructInfo", InferStructInfoAttention); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h new file mode 100644 index 000000000000..662e0b7e7b81 --- /dev/null +++ b/src/relax/op/nn/attention.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file attention.h
+ * \brief The functions to make Relax attention operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_NN_ATTENTION_H_
+#define TVM_RELAX_OP_NN_ATTENTION_H_
+
+#include <tvm/relax/expr.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Fused multi head attention. */
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_NN_ATTENTION_H_
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 83104d6fe16c..36a1c4cd16ff 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -20,6 +20,8 @@
 
 import tvm
 import tvm.testing
+import tvm.topi.testing
+from tvm.contrib.pickle_memoize import memoize
 from tvm import relax, relay
 from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
 from tvm.relax.backend import get_patterns_with_prefix
@@ -300,5 +302,150 @@ def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype):
     tvm.ir.assert_structural_equal(mod, partition_for_cutlass(mod))
 
 
+@pytest.fixture(params=["float16", "float32"])
+def attention_dtype(request):
+    return request.param
+
+
+@pytest.fixture(
+    params=[
+        # B, (S, S_kv), N, (H, H_v)
+        (32, (8, 8), 16, (8, 8)),
+        (4, (16, 8), 32, (8, 8)),  # s != s_kv
+        (4, (16, 8), 32, (8, 16)),  # h != h_v
+        (32, (8, 8), 16, (4, 4)),  # h is not aligned
+        (2, (8, 8), 8, (256, 256)),  # needs output accumulator buffer
+    ]
+)
+def attention_size(request):
+    return request.param
+
+
+def get_relax_attention_module(q, k, v, bias=None):
+    dtype = str(q.dtype)
+
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import relax as relax_builder
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            q = R.arg("q", R.Tensor(q.shape, dtype))
+            k = R.arg("k", R.Tensor(k.shape, dtype))
+            v = R.arg("v", R.Tensor(v.shape, dtype))
+            if bias is not None:
+                bias = R.arg("bias", R.Tensor(bias.shape, dtype))
+            with R.dataflow() as frame:
+                result = R.emit(R.nn.attention(q, k, v, bias))
+                R.output(result)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_offload")
+def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, dtype):
+    q = np.random.randn(b, s, n, h).astype(dtype)
+    k = np.random.randn(b, s_kv, n, h).astype(dtype)
+    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
+    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
+    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    attn = tvm.topi.testing.softmax_python(score, -1)
+    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
+    ref = attn @ vt  # b, n, s, h_v
+    return q, k, v, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+
+
+def test_attention_offload(attention_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, ref = get_numpy_attention_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+
+    mod = get_relax_attention_module(q, k, v)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_4d_offload")
+def get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, dtype):
+    q = np.random.randn(b, s, n, h).astype(dtype)
+    k = np.random.randn(b, s_kv, n, h).astype(dtype)
+    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    bias = np.random.randn(b, n, s, s_kv).astype(dtype)
+    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
+    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
+    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    score_bias = score + bias  # b, n, s, s_kv
+    attn = tvm.topi.testing.softmax_python(score_bias, -1)
+    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
+    ref = attn @ vt  # b, n, s, h_v
+    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+
+
+def test_attention_bias_4d_offload(attention_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, bias, ref = get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+
+    mod = get_relax_attention_module(q, k, v, bias)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_3d_offload")
+def get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, dtype):
+    q = np.random.randn(b, s, n, h).astype(dtype)
+    k = np.random.randn(b, s_kv, n, h).astype(dtype)
+    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    bias = np.random.randn(b, s, s_kv).astype(dtype)
+    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
+    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
+    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    score_bias = score + bias.reshape(b, 1, s, s_kv)  # b, n, s, s_kv
+    attn = tvm.topi.testing.softmax_python(score_bias, -1)
+    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
+    ref = attn @ vt  # b, n, s, h_v
+    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+
+
+def test_attention_bias_3d_offload(attention_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, bias, ref = get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+
+    mod = get_relax_attention_module(q, k, v, bias)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_2d_offload")
+def get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, dtype):
+    q = np.random.randn(b, s, n, h).astype(dtype)
+    k = np.random.randn(b, s_kv, n, h).astype(dtype)
+    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    bias = np.random.randn(b, s_kv).astype(dtype)
+    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
+    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
+    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    score_bias = score + bias.reshape(b, 1, 1, s_kv)  # b, n, s, s_kv
+    attn = tvm.topi.testing.softmax_python(score_bias, -1)
+    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
+    ref = attn @ vt  # b, n, s, h_v
+    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+
+
+def test_attention_bias_2d_offload(attention_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, bias, ref = get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+
+    mod = get_relax_attention_module(q, k, v, bias)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
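
Usage sketch (not part of the diff): a minimal example of how the new attention op and the CUTLASS offload might be driven end to end, mirroring the helpers in tests/python/relax/test_codegen_cutlass.py. The get_relax_attention_module helper is the one added above; the RunCodegen option dict, the "sm"/"find_first_valid" keys, and the build/VM calls are assumptions based on the existing relax CUTLASS BYOC flow, not something this diff pins down.

import numpy as np
import tvm
from tvm import relax
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

# Random BSNH inputs: batch=4, seq_len=16, seq_len_kv=8, heads=32, head_dim=8.
q = np.random.randn(4, 16, 32, 8).astype("float16")
k = np.random.randn(4, 8, 32, 8).astype("float16")
v = np.random.randn(4, 8, 32, 8).astype("float16")

mod = get_relax_attention_module(q, k, v)  # main(q, k, v) = R.nn.attention(q, k, v)
mod = partition_for_cutlass(mod)           # offload the matched cutlass.attention composite
# Hand the partitioned functions to the CUTLASS codegen; sm=80 is an assumed target here.
mod = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})(mod)

dev = tvm.device("cuda", 0)
ex = relax.build(mod, target="cuda")
vm = relax.VirtualMachine(ex, dev)
out = vm["main"](tvm.nd.array(q, dev), tvm.nd.array(k, dev), tvm.nd.array(v, dev))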