From 04f177038b7109c849058aebdfdcec13b3b5f4ae Mon Sep 17 00:00:00 2001
From: Yaxing Cai
Date: Tue, 28 Feb 2023 15:58:19 -0800
Subject: [PATCH 01/11] [Unity][OP] Add an operator for fused multi head attention

This PR introduces the new relax operator `R.nn.attention` for fused multi
head attention, and adds support for fused multi head attention to the relax
CUTLASS BYOC. The inputs of the operator are the query, key and value
tensors, in `BSNH` layout, namely `[batch size, sequence length, number of
heads, dimension of heads]`. The output shares the same layout as the input
tensors.
---
 3rdparty/cutlass                              |  2 +-
 include/tvm/relax/attrs/nn.h                  |  9 ++
 .../contrib/cutlass/attention_operation.py    | 95 ++++++++++++++++++
 python/tvm/contrib/cutlass/build.py           | 41 ++++++++
 python/tvm/contrib/cutlass/gen_tensor_op.py   | 22 +++++
 python/tvm/relax/backend/contrib/cutlass.py   | 10 +-
 python/tvm/relax/backend/patterns.py          | 21 ++++
 python/tvm/relax/op/nn/nn.py                  | 40 ++++++++
 src/relax/op/nn/attention.cc                  | 97 +++++++++++++++++++
 src/relax/op/nn/attention.h                   | 41 ++++++++
 tests/python/relax/test_codegen_cutlass.py    | 75 ++++++++++++++
 11 files changed, 451 insertions(+), 2 deletions(-)
 create mode 100644 python/tvm/contrib/cutlass/attention_operation.py
 create mode 100644 src/relax/op/nn/attention.cc
 create mode 100644 src/relax/op/nn/attention.h

diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index d8359c804b7e..92ebbf1dc461 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit d8359c804b7e3915a0f0668c19213f63ae88aac6
+Subproject commit 92ebbf1dc4612bf838ace6f2e6d262919f0abd63
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 694a51070683..6a5853ad485c 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -184,6 +184,15 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
   }
 };  // struct DropoutAttrs
 
+/*! \brief Attributes used in fused multi head attention operator */
+struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
+    TVM_ATTR_FIELD(out_dtype).describe("The data type of the output tensor");
+  }
+};  // struct AttentionAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
new file mode 100644
index 000000000000..a0f1bbcb8b2f
--- /dev/null
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
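For reference, a minimal NumPy sketch of the computation `R.nn.attention` performs on `BSNH` inputs; it mirrors the `get_numpy_attention_ref` helper added to the tests later in this series, and the shapes and the `attention_ref` name are illustrative only.

import numpy as np

def attention_ref(q, k, v):
    # q: (B, S, N, H), k: (B, S', N, H), v: (B, S', N, H') -- BSNH layout
    qt = q.transpose(0, 2, 1, 3)                    # (B, N, S, H)
    kt = k.transpose(0, 2, 3, 1)                    # (B, N, H, S')
    score = qt @ kt / np.sqrt(q.shape[-1])          # (B, N, S, S'), scaled by 1/sqrt(H)
    score -= score.max(axis=-1, keepdims=True)      # numerically stable softmax over S'
    attn = np.exp(score)
    attn /= attn.sum(axis=-1, keepdims=True)
    out = attn @ v.transpose(0, 2, 1, 3)            # (B, N, S, H')
    return out.transpose(0, 2, 1, 3)                # back to BSNH: (B, S, N, H')

b, s, s_kv, n, h, h_v = 4, 8, 4, 32, 8, 16
q = np.random.randn(b, s, n, h).astype("float16")
k = np.random.randn(b, s_kv, n, h).astype("float16")
v = np.random.randn(b, s_kv, n, h_v).astype("float16")
print(attention_ref(q, k, v).shape)  # (4, 8, 32, 16)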
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import +"""Generator for CUTLASS attention kernels.""" +from .library import * + + +def instantiate_attention_template(attrs, func_args): + """Return CUTLASS host code for fused multi head attention + based on a template and the provided attribute map.""" + + template = """ + using T = cutlass::half_t; + + CHECK(${arg0}->ndim == 4); // B, S, N, H + CHECK(${arg1}->ndim == 4); // B, S', N, H + CHECK(${arg2}->ndim == 4); // B, S', N, H' + CHECK(out0->ndim == 4); // B, S, N, H' + + using Attention = + AttentionKernel; + + typename Attention::Params p; + + p.query_ptr = reinterpret_cast(${arg0}->data); + p.key_ptr = reinterpret_cast(${arg1}->data); + p.value_ptr = reinterpret_cast(${arg2}->data); + p.logsumexp_ptr = nullptr; + p.output_ptr = reinterpret_cast(out0->data); + static_assert(!Attention::kNeedsOutputAccumulatorBuffer); + p.output_accum_ptr = nullptr; + + p.num_heads = ${num_heads}; // N + p.num_batches = ${num_batches}; // B + p.head_dim = ${head_dim}; // H + p.head_dim_value = ${head_dim_value}; // H' + p.num_queries = ${num_queries}; // S + p.num_keys = ${num_keys}; // S' + p.scale = 1.0f / sqrt(float(${head_dim})); + // p.causal = false; + + // stride for N + p.q_strideH = p.head_dim; // H + p.k_strideH = p.head_dim; // H + p.v_strideH = p.head_dim_value; // H' + // p.o_strideH = p.head_dim_value; // H' + + // stride for S + p.q_strideM = p.q_strideH * p.num_heads; // H * N + p.k_strideM = p.k_strideH * p.num_heads; // H * N + p.v_strideM = p.v_strideH * p.num_heads; // H' * N + p.o_strideM = p.head_dim_value * p.num_heads; // H' * N + + // stride for B + p.q_strideB = p.q_strideM * p.num_queries; // H * N * S + p.k_strideB = p.k_strideM * p.num_keys; // H * N * S' + p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S' + // p.o_strideB = p.o_strideM * p.num_queries; // H'* N * S + + constexpr auto kernel_fn = attention_kernel_batched_impl; + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + static bool once = [&]() { + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + return true; + }(); + } + + CHECK(Attention::check_supported(p)); + kernel_fn<<>>(p); +""" + for i, arg in enumerate(func_args): + attrs["arg{}".format(i)] = arg + return substitute_template(template, attrs) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 7e81113f4431..c86df6e82978 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -57,6 +57,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False): cutlass_root = _get_cutlass_path() cutlass_include = os.path.join(cutlass_root, "include") cutlass_util_include = os.path.join(cutlass_root, "tools/util/include") + cutlass_attention_include = os.path.join(cutlass_root, "examples/41_fused_multi_head_attention") kwargs = {} kwargs["cc"] = "nvcc" @@ -71,6 +72,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False): "-std=c++17", "-I" + cutlass_include, "-I" + cutlass_util_include, + "-I" + cutlass_attention_include, ] if use_fast_math: kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID") @@ -756,6 +758,43 @@ def handle_matmul(self, f, op_type): } ) + def handle_attention(self, f, op_type): + """Tune and annotate a dense op.""" + signature = _extract_relax_function_signature(f) + + q_shape = signature["arg0_shape"] + k_shape = signature["arg1_shape"] + v_shape = signature["arg2_shape"] + out_shape = 
signature["ret_shape"] + q_dtype = signature["arg0_dtype"] + k_dtype = signature["arg1_dtype"] + v_dtype = signature["arg2_dtype"] + out_dtype = signature["ret_dtype"] + num_batches, num_queries, num_heads, head_dim = q_shape + _, num_keys, _, _ = k_shape + _, _, _, head_dim_value = v_shape + + return f.with_attrs( + { + "op_type": op_type, + "arg0_dtype": q_dtype, + "arg1_dtype": k_dtype, + "arg2_dtype": v_dtype, + "ret_dtype": out_dtype, + "arg0_shape": q_shape, + "arg1_shape": k_shape, + "arg2_shape": v_shape, + "ret_shape": out_shape, + "num_batches": num_batches, + "num_queries": num_queries, + "num_keys": num_keys, + "num_heads": num_heads, + "head_dim": head_dim, + "head_dim_value": head_dim_value, + "arch": self.options["sm"], + } + ) + def visit_function_(self, f): if "Composite" not in f.attrs: body = super().visit_expr(f.body) @@ -767,6 +806,8 @@ def visit_function_(self, f): return self.handle_conv2d(f, op_type) elif "matmul" in op_type: return self.handle_matmul(f, op_type) + elif "attention" in op_type: + return self.handle_attention(f, op_type) raise ValueError("Unsupported composite {}".format(op_type)) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 2976946dd258..4dddbb56de0d 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -30,6 +30,7 @@ from . import _ffi_api as ffi from .conv2d_operation import instantiate_conv2d_template from .gemm_operation import instantiate_gemm_template +from .attention_operation import instantiate_attention_template from .library import ( DataType, DataTypeTag, @@ -653,4 +654,25 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_ code = instantiate_conv2d_template(attrs, func_args) return CodegenResult(code, headers) + elif "attention" in func_name: + headers.append("kernel_forward.h") + attrs["num_batches"] = str(int(annotations["num_batches"])) + attrs["num_queries"] = str(int(annotations["num_queries"])) + attrs["num_keys"] = str(int(annotations["num_keys"])) + attrs["num_heads"] = str(int(annotations["num_heads"])) + attrs["head_dim"] = str(int(annotations["head_dim"])) + h_v = int(annotations["head_dim_value"]) + attrs["head_dim_value"] = str(h_v) + if h_v > 64: + attrs["kQueriesPerBlock"] = "32" + attrs["kKeysPerBlock"] = "128" + attrs["kSingleValueIteration"] = "true" if h_v <= 128 else "false" + else: + attrs["kQueriesPerBlock"] = "64" + attrs["kKeysPerBlock"] = "64" + attrs["kSingleValueIteration"] = "true" + attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"]) + code = instantiate_attention_template(attrs, func_args) + return CodegenResult(code, headers) + raise ValueError("Do not have a template for {}".format(func_name)) diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 2d8908184bd4..fed47aa6ecda 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -25,7 +25,11 @@ from tvm.relax.dpl import DFPattern from ..pattern_registry import get_patterns_with_prefix, register_patterns -from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern +from ..patterns import ( + make_fused_bias_activation_pattern, + make_matmul_pattern, + make_attention_pattern, +) def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]: @@ -157,6 +161,10 @@ def _check_matmul( ), _check_matmul, ), + ( + "cutlass.attention", + make_attention_pattern(), + ), ] ) diff 
--git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index 2f744af66002..85245767481b 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -113,3 +113,24 @@ def make_matmul_pattern( out = is_op("relax.matmul")(lhs, rhs) return _with_bias_activation_pattern(out, args, with_bias, activation) + + +def make_attention_pattern(): + """ + Create pattern for fused multi head attention. + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a fused multi head attention. + + args: Mapping[str, DFPattern] + The mapping from arg name to its pattern. It can be used to extract + arg expression from match result. + """ + query = wildcard() + key = wildcard() + value = wildcard() + out = is_op("relax.nn.attention")(query, key, value) + + return out diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 0ff143fd045b..347f63ec90af 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -577,3 +577,43 @@ def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr: The computed result. """ return _ffi_api.cross_entropy_with_logits(predictions, labels) # type: ignore + + +def attention( + query: Expr, key: Expr, value: Expr, out_dtype: Optional[Union[str, DataType]] = None +) -> Expr: + r"""Computes fused multi head attention. + + All input tensors are of 4-D tensors with BSNH layout. + + .. math:: + FMA(Q, K, V) = \text{Softmax}(Q @ K^T) @ V + + .. note:: + The input tensor is required to have float16 dtype + + Parameters + ---------- + query: relax.Expr + The input query to the operator. The layout of the input query should be + (batch_size, seq_len, num_head, head_dim). + + key: relax.Expr + The input key to the operator. The layout of the input key should be + (batch_size, seq_len_kv, num_head, head_dim). + + value: relax.Expr + The input value to the operator. The layout of the input value should be + (batch_size, seq_len_kv, num_head, head_dim_v). + + out_dtype: Optional[Union[str, DataType]] + The data type of the attention result. + When it is not specified, the output dtype will be the the same as input dtype. + + Returns + ------- + result : relax.Expr + The computed result. The layout of the output should be + (batch_size, seq_len, num_head, head_dim_v). + """ + return _ffi_api.attention(query, key, value, out_dtype) # type: ignore diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc new file mode 100644 index 000000000000..c6e08177c683 --- /dev/null +++ b/src/relax/op/nn/attention.cc @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "attention.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.nn.attention */ +TVM_REGISTER_NODE_TYPE(AttentionAttrs); + +Expr attention(Expr query, Expr key, Expr value, DataType out_dtype) { + ObjectPtr attrs = make_object(); + attrs->out_dtype = out_dtype; + static const Op& op = Op::Get("relax.nn.attention"); + return Call(op, {std::move(query), std::move(key), std::move(value)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention); + +StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo q_sinfo = input_sinfo[0]; + TensorStructInfo k_sinfo = input_sinfo[1]; + TensorStructInfo v_sinfo = input_sinfo[2]; + auto diag_dim = [&](TensorStructInfo sinfo, String name) { + if (sinfo->ndim != 4) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The " << name << " should have 4 dimension, namely " + << "[batch size, sequence length, number of heads, dimension of heads]."); + } + }; + diag_dim(q_sinfo, "query"); + diag_dim(k_sinfo, "key"); + diag_dim(v_sinfo, "value"); + const auto* attrs = call->attrs.as(); + DataType out_dtype = attrs->out_dtype.is_void() + ? InferBinaryArithOpOutDtype(call, ctx, q_sinfo, k_sinfo) + : attrs->out_dtype; + const ShapeExprNode* q_shape = q_sinfo->shape.as(); + const ShapeExprNode* k_shape = k_sinfo->shape.as(); + const ShapeExprNode* v_shape = v_sinfo->shape.as(); + PrimExpr num_batches = q_shape->values[0]; + PrimExpr num_queries = q_shape->values[1]; + PrimExpr num_heads = q_shape->values[2]; + PrimExpr head_dim = q_shape->values[3]; + PrimExpr num_keys = k_shape->values[1]; + PrimExpr head_dim_value = v_shape->values[3]; + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + auto diag_equal = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) { + if (analyzer->CanProve(v1 != v2)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The " << m1 << " " << dim << " and the " << m2 << " " << dim + << " should be the same. However, the " << dim << " of " << m1 << " is " + << v1 << " while the " << dim << " of " << m2 << " is " << v2); + } + }; + diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size"); + diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size"); + diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads"); + diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads"); + diag_equal(k_shape->values[1], v_shape->values[1], "key", "value", "sequence length"); + diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads"); + + Array output_shape = {num_batches, num_queries, num_heads, head_dim_value}; + return TensorStructInfo(ShapeExpr(output_shape), out_dtype); +} + +TVM_REGISTER_OP("relax.nn.attention") + .set_num_inputs(3) + .add_argument("query", "Tensor", "The input queries tensor.") + .add_argument("key", "Tensor", "The input keys tensor.") + .add_argument("value", "Tensor", "The input values tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoAttention); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h new file mode 100644 index 000000000000..b32c32bac104 --- /dev/null +++ b/src/relax/op/nn/attention.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file attention.h + * \brief The functions to make Relax attention operator calls. + */ + +#ifndef TVM_RELAX_OP_NN_ATTENTION_H_ +#define TVM_RELAX_OP_NN_ATTENTION_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief fused multi head attention */ +Expr attention(Expr query, Expr key, Expr value, DataType out_dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_NN_ATTENTION_H_ diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 83104d6fe16c..e9d3bfa9a715 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -20,6 +20,7 @@ import tvm import tvm.testing +import tvm.topi.testing from tvm import relax, relay from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul from tvm.relax.backend import get_patterns_with_prefix @@ -300,5 +301,79 @@ def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype): tvm.ir.assert_structural_equal(mod, partition_for_cutlass(mod)) +@pytest.fixture( + params=[ + # B, S, N, H + (32, (4, 4), 16, (8, 8)), + (4, (8, 4), 32, (8, 8)), + (4, (8, 4), 32, (8, 16)), + ] +) +def attention_size(request): + return request.param + + +@pytest.fixture +def attention_q(attention_size, target_dtype): + b, (s, _), n, (h, _) = attention_size + return np.random.randn(b, s, n, h).astype(target_dtype) + + +@pytest.fixture +def attention_k(attention_size, target_dtype): + b, (_, s), n, (h, _) = attention_size + return np.random.randn(b, s, n, h).astype(target_dtype) + + +@pytest.fixture +def attention_v(attention_size, target_dtype): + b, (_, s), n, (_, h) = attention_size + return np.random.randn(b, s, n, h).astype(target_dtype) + + +def get_relax_attention_module(q, k, v): + dtype = str(q.dtype) + + from tvm.script.ir_builder import IRBuilder + from tvm.script.ir_builder import relax as relax_builder + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + q = R.arg("q", R.Tensor(q.shape, dtype)) + k = R.arg("k", R.Tensor(k.shape, dtype)) + v = R.arg("v", R.Tensor(v.shape, dtype)) + + with R.dataflow() as frame: + result = R.emit(R.nn.attention(q, k, v)) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_numpy_attention_ref(q, k, v): + qt = q.transpose(0, 2, 1, 3) # b, n, s, h + kt = k.transpose(0, 2, 3, 1) # b, n, h, s + score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s + attn = tvm.topi.testing.softmax_python(score, -1) + vt = v.transpose(0, 2, 1, 3) # b, n, s, h + ref = attn @ vt # b, n, s, h + return ref.transpose(0, 2, 1, 3) # b, s, n, h + + +def test_attention_offload(attention_q, attention_k, attention_v): + q, k, v = attention_q, attention_k, 
attention_v + + mod = get_relax_attention_module(q, k, v) + out = get_result_with_relax_cutlass_offload(mod, q, k, v) + + ref = get_numpy_attention_ref(q, k, v) + + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + if __name__ == "__main__": tvm.testing.main() From 54d43d93fc218b91047ad530c9924498355aab1c Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Wed, 1 Mar 2023 01:03:59 -0800 Subject: [PATCH 02/11] remove useless codes, remove attrs and add memoize --- include/tvm/relax/attrs/nn.h | 9 ---- .../contrib/cutlass/attention_operation.py | 2 - src/relax/op/nn/attention.cc | 13 +---- tests/python/relax/test_codegen_cutlass.py | 47 ++++++++----------- 4 files changed, 21 insertions(+), 50 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 6a5853ad485c..694a51070683 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -184,15 +184,6 @@ struct DropoutAttrs : public tvm::AttrsNode { } }; // struct DropoutAttrs -/*! \brief Attributes used in fuse multi head attention operator */ -struct AttentionAttrs : public tvm::AttrsNode { - DataType out_dtype; - - TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") { - TVM_ATTR_FIELD(out_dtype).describe("The data type of the output tensor"); - } -}; // struct AttentionAttrs - } // namespace relax } // namespace tvm diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index a0f1bbcb8b2f..ebfbd5e94599 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -63,7 +63,6 @@ def instantiate_attention_template(attrs, func_args): p.q_strideH = p.head_dim; // H p.k_strideH = p.head_dim; // H p.v_strideH = p.head_dim_value; // H' - // p.o_strideH = p.head_dim_value; // H' // stride for S p.q_strideM = p.q_strideH * p.num_heads; // H * N @@ -75,7 +74,6 @@ def instantiate_attention_template(attrs, func_args): p.q_strideB = p.q_strideM * p.num_queries; // H * N * S p.k_strideB = p.k_strideM * p.num_keys; // H * N * S' p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S' - // p.o_strideB = p.o_strideM * p.num_queries; // H'* N * S constexpr auto kernel_fn = attention_kernel_batched_impl; int smem_bytes = sizeof(typename Attention::SharedStorage); diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index c6e08177c683..85e77bf2bf81 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -26,13 +26,9 @@ namespace tvm { namespace relax { /* relax.nn.attention */ -TVM_REGISTER_NODE_TYPE(AttentionAttrs); - Expr attention(Expr query, Expr key, Expr value, DataType out_dtype) { - ObjectPtr attrs = make_object(); - attrs->out_dtype = out_dtype; static const Op& op = Op::Get("relax.nn.attention"); - return Call(op, {std::move(query), std::move(key), std::move(value)}, Attrs(attrs), {}); + return Call(op, {std::move(query), std::move(key), std::move(value)}, {}, {}); } TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention); @@ -52,10 +48,6 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { diag_dim(q_sinfo, "query"); diag_dim(k_sinfo, "key"); diag_dim(v_sinfo, "value"); - const auto* attrs = call->attrs.as(); - DataType out_dtype = attrs->out_dtype.is_void() - ? 
InferBinaryArithOpOutDtype(call, ctx, q_sinfo, k_sinfo) - : attrs->out_dtype; const ShapeExprNode* q_shape = q_sinfo->shape.as(); const ShapeExprNode* k_shape = k_sinfo->shape.as(); const ShapeExprNode* v_shape = v_sinfo->shape.as(); @@ -82,7 +74,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads"); Array output_shape = {num_batches, num_queries, num_heads, head_dim_value}; - return TensorStructInfo(ShapeExpr(output_shape), out_dtype); + return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype); } TVM_REGISTER_OP("relax.nn.attention") @@ -90,7 +82,6 @@ TVM_REGISTER_OP("relax.nn.attention") .add_argument("query", "Tensor", "The input queries tensor.") .add_argument("key", "Tensor", "The input keys tensor.") .add_argument("value", "Tensor", "The input values tensor.") - .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoAttention); } // namespace relax diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index e9d3bfa9a715..65b94c88fa58 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -21,6 +21,7 @@ import tvm import tvm.testing import tvm.topi.testing +from tvm.contrib.pickle_memoize import memoize from tvm import relax, relay from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul from tvm.relax.backend import get_patterns_with_prefix @@ -301,6 +302,11 @@ def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype): tvm.ir.assert_structural_equal(mod, partition_for_cutlass(mod)) +@pytest.fixture(params=["float16"]) +def attention_dtype(request): + return request.param + + @pytest.fixture( params=[ # B, S, N, H @@ -313,24 +319,6 @@ def attention_size(request): return request.param -@pytest.fixture -def attention_q(attention_size, target_dtype): - b, (s, _), n, (h, _) = attention_size - return np.random.randn(b, s, n, h).astype(target_dtype) - - -@pytest.fixture -def attention_k(attention_size, target_dtype): - b, (_, s), n, (h, _) = attention_size - return np.random.randn(b, s, n, h).astype(target_dtype) - - -@pytest.fixture -def attention_v(attention_size, target_dtype): - b, (_, s), n, (_, h) = attention_size - return np.random.randn(b, s, n, h).astype(target_dtype) - - def get_relax_attention_module(q, k, v): dtype = str(q.dtype) @@ -354,24 +342,27 @@ def get_relax_attention_module(q, k, v): return tvm.IRModule({"main": func}) -def get_numpy_attention_ref(q, k, v): +@memoize("topi.tests.test_codegen_cutlass.test_attention_offload") +def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, dtype): + q = np.random.randn(b, s, n, h).astype(dtype) + k = np.random.randn(b, s_kv, n, h).astype(dtype) + v = np.random.randn(b, s_kv, n, h_v).astype(dtype) qt = q.transpose(0, 2, 1, 3) # b, n, s, h - kt = k.transpose(0, 2, 3, 1) # b, n, h, s - score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s + kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv + score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv attn = tvm.topi.testing.softmax_python(score, -1) - vt = v.transpose(0, 2, 1, 3) # b, n, s, h - ref = attn @ vt # b, n, s, h - return ref.transpose(0, 2, 1, 3) # b, s, n, h + vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v + ref = attn @ vt # b, n, s, h_v + return q, k, v, ref.transpose(0, 2, 1, 3) # b, s, n, h_v -def test_attention_offload(attention_q, attention_k, attention_v): - q, k, v = attention_q, attention_k, attention_v +def 
test_attention_offload(attention_size, attention_dtype): + b, (s, s_kv), n, (h, h_v) = attention_size + q, k, v, ref = get_numpy_attention_ref(b, s, s_kv, n, h, h_v, attention_dtype) mod = get_relax_attention_module(q, k, v) out = get_result_with_relax_cutlass_offload(mod, q, k, v) - ref = get_numpy_attention_ref(q, k, v) - tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) From 7c1eb7f69017278bb9fbe0c927e8f2e2b071a939 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 2 Mar 2023 15:29:28 -0800 Subject: [PATCH 03/11] add more dispatches --- .../contrib/cutlass/attention_operation.py | 12 +++-- python/tvm/contrib/cutlass/gen_tensor_op.py | 52 +++++++++++-------- python/tvm/contrib/cutlass/library.py | 6 +++ tests/python/relax/test_codegen_cutlass.py | 8 +-- 4 files changed, 50 insertions(+), 28 deletions(-) diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index ebfbd5e94599..06609807e567 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -24,7 +24,7 @@ def instantiate_attention_template(attrs, func_args): based on a template and the provided attribute map.""" template = """ - using T = cutlass::half_t; + using T = ${data_type}; CHECK(${arg0}->ndim == 4); // B, S, N, H CHECK(${arg1}->ndim == 4); // B, S', N, H @@ -34,7 +34,7 @@ def instantiate_attention_template(attrs, func_args): using Attention = AttentionKernel(${arg2}->data); p.logsumexp_ptr = nullptr; p.output_ptr = reinterpret_cast(out0->data); - static_assert(!Attention::kNeedsOutputAccumulatorBuffer); p.output_accum_ptr = nullptr; + if (Attention::kNeedsOutputAccumulatorBuffer) { + cudaMalloc( + &p.output_accum_ptr, + ${output_size} * sizeof(Attention::output_accum_t) + ); + } p.num_heads = ${num_heads}; // N p.num_batches = ${num_batches}; // B @@ -57,7 +62,6 @@ def instantiate_attention_template(attrs, func_args): p.num_queries = ${num_queries}; // S p.num_keys = ${num_keys}; // S' p.scale = 1.0f / sqrt(float(${head_dim})); - // p.causal = false; // stride for N p.q_strideH = p.head_dim; // H diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 4dddbb56de0d..953d4cda1385 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -33,6 +33,7 @@ from .attention_operation import instantiate_attention_template from .library import ( DataType, + DataTypeSize, DataTypeTag, EpilogueFunctor, MathInstruction, @@ -550,8 +551,8 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_ attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]] attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]] - attrs["K"] = str(int(lhs_shape[lhs_batched_offset + 1])) - attrs["M"] = get_dim(lhs_shape[lhs_batched_offset], lhs_arg, 0, lhs_batched_offset) + attrs["K"] = lhs_shape[batched_offset + 1] + attrs["M"] = get_dim(lhs_shape[batched_offset], lhs_arg, 0, batched_offset) if transposed: attrs["N"] = get_dim(rhs_shape[rhs_batched_offset], rhs_arg, 0, rhs_batched_offset) @@ -631,18 +632,18 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_ attrs["N"] = get_dim(activation_shape[0], activation_var, 0) attrs["H"] = get_dim(activation_shape[1], activation_var, 1) attrs["W"] = get_dim(activation_shape[2], activation_var, 2) - attrs["C"] = str(int(activation_shape[3])) + attrs["C"] = activation_shape[3] 
attrs["P"] = get_dim(output_shape[1], "out0", 1) attrs["Q"] = get_dim(output_shape[2], "out0", 2) - attrs["K"] = str(int(output_shape[3])) - attrs["R"] = str(int(weight_shape[1])) - attrs["S"] = str(int(weight_shape[2])) - attrs["pad_h"] = str(int(annotations["padding"][0])) - attrs["pad_w"] = str(int(annotations["padding"][1])) - attrs["stride_h"] = str(int(annotations["strides"][0])) - attrs["stride_w"] = str(int(annotations["strides"][1])) - attrs["dilation_h"] = str(int(annotations["dilation"][0])) - attrs["dilation_w"] = str(int(annotations["dilation"][1])) + attrs["K"] = output_shape[3] + attrs["R"] = weight_shape[1] + attrs["S"] = weight_shape[2] + attrs["pad_h"] = annotations["padding"][0] + attrs["pad_w"] = annotations["padding"][1] + attrs["stride_h"] = annotations["strides"][0] + attrs["stride_w"] = annotations["strides"][1] + attrs["dilation_h"] = annotations["dilation"][0] + attrs["dilation_w"] = annotations["dilation"][1] if "splitk" in op_name: attrs["split_k_mode"] = "kParallel" @@ -656,21 +657,30 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_ elif "attention" in func_name: headers.append("kernel_forward.h") - attrs["num_batches"] = str(int(annotations["num_batches"])) - attrs["num_queries"] = str(int(annotations["num_queries"])) - attrs["num_keys"] = str(int(annotations["num_keys"])) - attrs["num_heads"] = str(int(annotations["num_heads"])) - attrs["head_dim"] = str(int(annotations["head_dim"])) - h_v = int(annotations["head_dim_value"]) - attrs["head_dim_value"] = str(h_v) + data_type = dtype_map[annotations["arg0_dtype"]] + attrs["data_type"] = DataTypeTag[data_type] + attrs["num_batches"] = b = annotations["num_batches"] + attrs["num_queries"] = s = annotations["num_queries"] + attrs["num_keys"] = annotations["num_keys"] + attrs["num_heads"] = n = annotations["num_heads"] + attrs["head_dim"] = h = annotations["head_dim"] + attrs["head_dim_value"] = h_v = annotations["head_dim_value"] + data_type_size = DataTypeSize[data_type] + if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0: + attrs["kIsAligned"] = True + elif (h % 4 == 0) and (h_v % 4 == 0): + attrs["kIsAligned"] = False + else: + raise NotImplementedError() if h_v > 64: attrs["kQueriesPerBlock"] = "32" attrs["kKeysPerBlock"] = "128" - attrs["kSingleValueIteration"] = "true" if h_v <= 128 else "false" + attrs["kSingleValueIteration"] = h_v <= 128 else: attrs["kQueriesPerBlock"] = "64" attrs["kKeysPerBlock"] = "64" - attrs["kSingleValueIteration"] = "true" + attrs["kSingleValueIteration"] = True + attrs["output_size"] = b * s * n * h_v attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"]) code = instantiate_attention_template(attrs, func_args) return CodegenResult(code, headers) diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index 8632ab15641d..b72553ef6052 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -20,6 +20,8 @@ import enum from enum import auto as enum_auto +from tvm.tir.expr import IntImm + class GeneratorTarget(enum.Enum): Library = enum_auto() @@ -143,6 +145,10 @@ def substitute_template(template, values): while changed: changed = False for key, value in values.items(): + if isinstance(value, (int, IntImm)): + value = str(int(value)) + elif isinstance(value, bool): + value = str(value).lower() regex = "\\$\\{%s\\}" % key newtext = re.sub(regex, value, text) if newtext != text: diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py index 65b94c88fa58..809188e799f0 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -302,7 +302,7 @@ def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype): tvm.ir.assert_structural_equal(mod, partition_for_cutlass(mod)) -@pytest.fixture(params=["float16"]) +@pytest.fixture(params=["float16", "float32"]) def attention_dtype(request): return request.param @@ -311,8 +311,10 @@ def attention_dtype(request): params=[ # B, S, N, H (32, (4, 4), 16, (8, 8)), - (4, (8, 4), 32, (8, 8)), - (4, (8, 4), 32, (8, 16)), + (4, (8, 4), 32, (8, 8)), # s != s_kv + (4, (8, 4), 32, (8, 16)), # h != h_v + (32, (4, 4), 16, (4, 4)), # h is not aligned + (2, (4, 4), 8, (256, 256)), # needs output accumulator buffer ] ) def attention_size(request): From 7d557665d2ece8d248bdccf02e7ffb21ad5a022f Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 2 Mar 2023 15:41:31 -0800 Subject: [PATCH 04/11] nit and fix rebase --- python/tvm/contrib/cutlass/gen_tensor_op.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 953d4cda1385..964d2f142e4a 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -551,8 +551,8 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_ attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]] attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]] - attrs["K"] = lhs_shape[batched_offset + 1] - attrs["M"] = get_dim(lhs_shape[batched_offset], lhs_arg, 0, batched_offset) + attrs["K"] = lhs_shape[lhs_batched_offset + 1] + attrs["M"] = get_dim(lhs_shape[lhs_batched_offset], lhs_arg, 0, lhs_batched_offset) if transposed: attrs["N"] = get_dim(rhs_shape[rhs_batched_offset], rhs_arg, 0, rhs_batched_offset) @@ -650,7 +650,7 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_ attrs["split_k_slices"] = str(re.search(r"splitk(\d+)", op_name).group(1)) else: attrs["split_k_mode"] = "kSerial" - attrs["split_k_slices"] = "1" + attrs["split_k_slices"] = 1 code = instantiate_conv2d_template(attrs, func_args) return CodegenResult(code, headers) @@ -673,12 +673,12 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_ else: raise NotImplementedError() if h_v > 64: - attrs["kQueriesPerBlock"] = "32" - attrs["kKeysPerBlock"] = "128" + attrs["kQueriesPerBlock"] = 32 + attrs["kKeysPerBlock"] = 128 attrs["kSingleValueIteration"] = h_v <= 128 else: - attrs["kQueriesPerBlock"] = "64" - attrs["kKeysPerBlock"] = "64" + attrs["kQueriesPerBlock"] = 64 + attrs["kKeysPerBlock"] = 64 attrs["kSingleValueIteration"] = True attrs["output_size"] = b * s * n * h_v attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"]) From a37dc4a77c22b115f96401419fb9f7953d06d676 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 2 Mar 2023 15:56:41 -0800 Subject: [PATCH 05/11] fix linter --- python/tvm/contrib/cutlass/attention_operation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index 06609807e567..3a8b9530201a 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -50,7 +50,7 @@ def 
instantiate_attention_template(attrs, func_args): p.output_accum_ptr = nullptr; if (Attention::kNeedsOutputAccumulatorBuffer) { cudaMalloc( - &p.output_accum_ptr, + &p.output_accum_ptr, ${output_size} * sizeof(Attention::output_accum_t) ); } From b8db5e2fbca27e14db103ba93e0e9e0cb142c2bb Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 2 Mar 2023 21:17:08 -0800 Subject: [PATCH 06/11] add support for bias --- .../contrib/cutlass/attention_operation.py | 16 +++++++ python/tvm/relax/backend/contrib/cutlass.py | 6 ++- python/tvm/relax/backend/patterns.py | 12 ++++-- python/tvm/relax/op/nn/nn.py | 12 +++--- src/relax/op/nn/attention.cc | 34 +++++++++++++-- src/relax/op/nn/attention.h | 2 +- tests/python/relax/test_codegen_cutlass.py | 43 +++++++++++++++---- 7 files changed, 102 insertions(+), 23 deletions(-) diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index 3a8b9530201a..81641c9c54cd 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -23,6 +23,16 @@ def instantiate_attention_template(attrs, func_args): """Return CUTLASS host code for fused multi head attention based on a template and the provided attribute map.""" + bias_template = """ + CHECK(${arg3}->ndim == 4); // B, S, N, S' + + p.attn_bias_ptr = reinterpret_cast(${arg3}->data); + p.bias_strideH = p.num_keys; // S' + p.bias_strideM = p.bias_strideH * p.num_heads; // S' * N + p.bias_strideB = p.bias_strideM * p.num_queries; // S' * N * S + +""" + template = """ using T = ${data_type}; @@ -79,6 +89,8 @@ def instantiate_attention_template(attrs, func_args): p.k_strideB = p.k_strideM * p.num_keys; // H * N * S' p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S' + ${bias_template} + constexpr auto kernel_fn = attention_kernel_batched_impl; int smem_bytes = sizeof(typename Attention::SharedStorage); if (smem_bytes > 0xc000) { @@ -92,6 +104,10 @@ def instantiate_attention_template(attrs, func_args): CHECK(Attention::check_supported(p)); kernel_fn<<>>(p); """ + if len(func_args) > 3: + template = substitute_template(template, {"bias_template": bias_template}) + else: + template = substitute_template(template, {"bias_template": ""}) for i, arg in enumerate(func_args): attrs["arg{}".format(i)] = arg return substitute_template(template, attrs) diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index fed47aa6ecda..19165fa8329f 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -163,7 +163,11 @@ def _check_matmul( ), ( "cutlass.attention", - make_attention_pattern(), + *make_attention_pattern(), + ), + ( + "cutlass.attention_bias", + *make_attention_pattern(with_bias=True), ), ] ) diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index 85245767481b..a2ea803d9d92 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -115,7 +115,7 @@ def make_matmul_pattern( return _with_bias_activation_pattern(out, args, with_bias, activation) -def make_attention_pattern(): +def make_attention_pattern(with_bias: bool = False): """ Create pattern for fused multi head attention. 
@@ -131,6 +131,12 @@ def make_attention_pattern(): query = wildcard() key = wildcard() value = wildcard() - out = is_op("relax.nn.attention")(query, key, value) + args = {"query": query, "key": key, "value": value} + if with_bias: + bias = wildcard() + args["bias"] = bias + out = is_op("relax.nn.attention_bias")(query, key, value, bias) + else: + out = is_op("relax.nn.attention")(query, key, value) - return out + return out, args diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 347f63ec90af..c028fcc606c4 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -579,9 +579,7 @@ def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr: return _ffi_api.cross_entropy_with_logits(predictions, labels) # type: ignore -def attention( - query: Expr, key: Expr, value: Expr, out_dtype: Optional[Union[str, DataType]] = None -) -> Expr: +def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None) -> Expr: r"""Computes fused multi head attention. All input tensors are of 4-D tensors with BSNH layout. @@ -606,9 +604,9 @@ def attention( The input value to the operator. The layout of the input value should be (batch_size, seq_len_kv, num_head, head_dim_v). - out_dtype: Optional[Union[str, DataType]] - The data type of the attention result. - When it is not specified, the output dtype will be the the same as input dtype. + bias: Optional[Expr] + The optional attention bias to the operator. The layout of the attention bias should be + (batch_size, seq_len, num_head, seq_len_kv). Returns ------- @@ -616,4 +614,4 @@ def attention( The computed result. The layout of the output should be (batch_size, seq_len, num_head, head_dim_v). """ - return _ffi_api.attention(query, key, value, out_dtype) # type: ignore + return _ffi_api.attention(query, key, value, bias) # type: ignore diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index 85e77bf2bf81..9c3c53acfe93 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -26,9 +26,15 @@ namespace tvm { namespace relax { /* relax.nn.attention */ -Expr attention(Expr query, Expr key, Expr value, DataType out_dtype) { +Expr attention(Expr query, Expr key, Expr value, Optional bias) { static const Op& op = Op::Get("relax.nn.attention"); - return Call(op, {std::move(query), std::move(key), std::move(value)}, {}, {}); + if (bias.defined()) { + return Call(Op::Get("relax.nn.attention_bias"), + {std::move(query), std::move(key), std::move(value), std::move(bias.value())}, {}, + {}); + } + return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)}, + {}, {}); } TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention); @@ -70,9 +76,23 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size"); diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads"); diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads"); - diag_equal(k_shape->values[1], v_shape->values[1], "key", "value", "sequence length"); + diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length"); diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads"); + if (input_sinfo.size() == 4) { + TensorStructInfo bias_sinfo = input_sinfo[3]; + if (bias_sinfo->ndim != 4) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The bias should have 4 dimension, namely " + << 
"[batch size, sequence length, number of heads, sequence length]."); + } + const ShapeExprNode* bias_shape = bias_sinfo->shape.as(); + diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size"); + diag_equal(num_queries, bias_shape->values[1], "query", "bias", "sequence length"); + diag_equal(num_heads, bias_shape->values[2], "query", "bias", "number of heads"); + diag_equal(num_keys, bias_shape->values[3], "key", "bias", "sequence length"); + } + Array output_shape = {num_batches, num_queries, num_heads, head_dim_value}; return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype); } @@ -84,5 +104,13 @@ TVM_REGISTER_OP("relax.nn.attention") .add_argument("value", "Tensor", "The input values tensor.") .set_attr("FInferStructInfo", InferStructInfoAttention); +TVM_REGISTER_OP("relax.nn.attention_bias") + .set_num_inputs(4) + .add_argument("query", "Tensor", "The input queries tensor.") + .add_argument("key", "Tensor", "The input keys tensor.") + .add_argument("value", "Tensor", "The input values tensor.") + .add_argument("bias", "Tensor", "The input bias tensor.") + .set_attr("FInferStructInfo", InferStructInfoAttention); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h index b32c32bac104..662e0b7e7b81 100644 --- a/src/relax/op/nn/attention.h +++ b/src/relax/op/nn/attention.h @@ -33,7 +33,7 @@ namespace tvm { namespace relax { /*! \brief fused multi head attention */ -Expr attention(Expr query, Expr key, Expr value, DataType out_dtype); +Expr attention(Expr query, Expr key, Expr value, Optional bias); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 809188e799f0..3deb6dfdd535 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -310,18 +310,18 @@ def attention_dtype(request): @pytest.fixture( params=[ # B, S, N, H - (32, (4, 4), 16, (8, 8)), - (4, (8, 4), 32, (8, 8)), # s != s_kv - (4, (8, 4), 32, (8, 16)), # h != h_v - (32, (4, 4), 16, (4, 4)), # h is not aligned - (2, (4, 4), 8, (256, 256)), # needs output accumulator buffer + (32, (8, 8), 16, (8, 8)), + (4, (16, 8), 32, (8, 8)), # s != s_kv + (4, (16, 8), 32, (8, 16)), # h != h_v + (32, (8, 8), 16, (4, 4)), # h is not aligned + (2, (8, 8), 8, (256, 256)), # needs output accumulator buffer ] ) def attention_size(request): return request.param -def get_relax_attention_module(q, k, v): +def get_relax_attention_module(q, k, v, bias=None): dtype = str(q.dtype) from tvm.script.ir_builder import IRBuilder @@ -333,9 +333,10 @@ def get_relax_attention_module(q, k, v): q = R.arg("q", R.Tensor(q.shape, dtype)) k = R.arg("k", R.Tensor(k.shape, dtype)) v = R.arg("v", R.Tensor(v.shape, dtype)) - + if bias is not None: + bias = R.arg("bias", R.Tensor(bias.shape, dtype)) with R.dataflow() as frame: - result = R.emit(R.nn.attention(q, k, v)) + result = R.emit(R.nn.attention(q, k, v, bias)) R.output(result) R.func_ret_value(frame.output_vars[0]) @@ -368,5 +369,31 @@ def test_attention_offload(attention_size, attention_dtype): tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) +@memoize("topi.tests.test_codegen_cutlass.test_attention_offload") +def get_numpy_attention_bias_ref(b, s, s_kv, n, h, h_v, dtype): + q = np.random.randn(b, s, n, h).astype(dtype) + k = np.random.randn(b, s_kv, n, h).astype(dtype) + v = np.random.randn(b, s_kv, n, h_v).astype(dtype) + bias = np.random.randn(b, s, n, s_kv).astype(dtype) + 
qt = q.transpose(0, 2, 1, 3) # b, n, s, h + kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv + score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv + score_bias = score + bias.transpose(0, 2, 1, 3) # b, n, s, s_kv + attn = tvm.topi.testing.softmax_python(score_bias, -1) + vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v + ref = attn @ vt # b, n, s, h_v + return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v + + +def test_attention_bias_offload(attention_size, attention_dtype): + b, (s, s_kv), n, (h, h_v) = attention_size + q, k, v, bias, ref = get_numpy_attention_bias_ref(b, s, s_kv, n, h, h_v, attention_dtype) + + mod = get_relax_attention_module(q, k, v, bias) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias) + + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + if __name__ == "__main__": tvm.testing.main() From 5518599d7ab9bbef8aac455f4f0168c1b5a73cfc Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 2 Mar 2023 21:31:30 -0800 Subject: [PATCH 07/11] fix lint --- python/tvm/contrib/cutlass/attention_operation.py | 8 +++++--- python/tvm/contrib/cutlass/gen_tensor_op.py | 2 ++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index 81641c9c54cd..53c3af187497 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -30,7 +30,7 @@ def instantiate_attention_template(attrs, func_args): p.bias_strideH = p.num_keys; // S' p.bias_strideM = p.bias_strideH * p.num_heads; // S' * N p.bias_strideB = p.bias_strideM * p.num_queries; // S' * N * S - + """ template = """ @@ -47,7 +47,9 @@ def instantiate_attention_template(attrs, func_args): /*is_aligned=*/${kIsAligned}, /*queries_per_block=*/${kQueriesPerBlock}, /*keys_per_block=*/${kKeysPerBlock}, - /*single_value_iteration=*/${kSingleValueIteration} + /*single_value_iteration=*/${kSingleValueIteration}, + /*supports_dropout=*/${kSupportsDropout}, + /*supports_bias=*/${kSupportsBias} >; typename Attention::Params p; @@ -104,7 +106,7 @@ def instantiate_attention_template(attrs, func_args): CHECK(Attention::check_supported(p)); kernel_fn<<>>(p); """ - if len(func_args) > 3: + if attrs["kSupportsBias"]: template = substitute_template(template, {"bias_template": bias_template}) else: template = substitute_template(template, {"bias_template": ""}) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 964d2f142e4a..2fe4ae5e3ff1 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -682,6 +682,8 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_ attrs["kSingleValueIteration"] = True attrs["output_size"] = b * s * n * h_v attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"]) + attrs["kSupportsDropout"] = False + attrs["kSupportsBias"] = len(func_args) > 3 code = instantiate_attention_template(attrs, func_args) return CodegenResult(code, headers) From 968c192af288a986fc286590cb80794751f7d51e Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 3 Mar 2023 10:56:46 -0800 Subject: [PATCH 08/11] BNSS layout for bias --- python/tvm/contrib/cutlass/attention_operation.py | 8 ++++---- src/relax/op/nn/attention.cc | 5 ++--- tests/python/relax/test_codegen_cutlass.py | 4 ++-- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/python/tvm/contrib/cutlass/attention_operation.py 
b/python/tvm/contrib/cutlass/attention_operation.py index 53c3af187497..14f57b132de6 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -24,12 +24,12 @@ def instantiate_attention_template(attrs, func_args): based on a template and the provided attribute map.""" bias_template = """ - CHECK(${arg3}->ndim == 4); // B, S, N, S' + CHECK(${arg3}->ndim == 4); // B, N, S, S' p.attn_bias_ptr = reinterpret_cast(${arg3}->data); - p.bias_strideH = p.num_keys; // S' - p.bias_strideM = p.bias_strideH * p.num_heads; // S' * N - p.bias_strideB = p.bias_strideM * p.num_queries; // S' * N * S + p.bias_strideM = p.num_keys; // S' + p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S + p.bias_strideB = p.bias_strideH * p.num_heads; // S' * S * N """ diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index 9c3c53acfe93..c8abf49c336e 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -27,7 +27,6 @@ namespace relax { /* relax.nn.attention */ Expr attention(Expr query, Expr key, Expr value, Optional bias) { - static const Op& op = Op::Get("relax.nn.attention"); if (bias.defined()) { return Call(Op::Get("relax.nn.attention_bias"), {std::move(query), std::move(key), std::move(value), std::move(bias.value())}, {}, @@ -88,8 +87,8 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { } const ShapeExprNode* bias_shape = bias_sinfo->shape.as(); diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size"); - diag_equal(num_queries, bias_shape->values[1], "query", "bias", "sequence length"); - diag_equal(num_heads, bias_shape->values[2], "query", "bias", "number of heads"); + diag_equal(num_heads, bias_shape->values[1], "query", "bias", "number of heads"); + diag_equal(num_queries, bias_shape->values[2], "query", "bias", "sequence length"); diag_equal(num_keys, bias_shape->values[3], "key", "bias", "sequence length"); } diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 3deb6dfdd535..be701337506a 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -374,11 +374,11 @@ def get_numpy_attention_bias_ref(b, s, s_kv, n, h, h_v, dtype): q = np.random.randn(b, s, n, h).astype(dtype) k = np.random.randn(b, s_kv, n, h).astype(dtype) v = np.random.randn(b, s_kv, n, h_v).astype(dtype) - bias = np.random.randn(b, s, n, s_kv).astype(dtype) + bias = np.random.randn(b, n, s, s_kv).astype(dtype) qt = q.transpose(0, 2, 1, 3) # b, n, s, h kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv - score_bias = score + bias.transpose(0, 2, 1, 3) # b, n, s, s_kv + score_bias = score + bias # b, n, s, s_kv attn = tvm.topi.testing.softmax_python(score_bias, -1) vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v ref = attn @ vt # b, n, s, h_v From 3f9b6f50cec3ecefa89eda5af1435074879a7165 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 3 Mar 2023 10:57:52 -0800 Subject: [PATCH 09/11] update doc --- python/tvm/relax/op/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index c028fcc606c4..3c7f0614bdfe 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -606,7 +606,7 @@ def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None) bias: Optional[Expr] The optional attention bias to the operator. 
The layout of the attention bias should be - (batch_size, seq_len, num_head, seq_len_kv). + (batch_size, num_head, seq_len, seq_len_kv). Returns ------- From 8abdef031008fc4bd83e9ead3dcb8bdd54cc49e7 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 3 Mar 2023 11:06:48 -0800 Subject: [PATCH 10/11] fix typo --- tests/python/relax/test_codegen_cutlass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index be701337506a..5da83a8e3dd6 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -369,7 +369,7 @@ def test_attention_offload(attention_size, attention_dtype): tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) -@memoize("topi.tests.test_codegen_cutlass.test_attention_offload") +@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_offload") def get_numpy_attention_bias_ref(b, s, s_kv, n, h, h_v, dtype): q = np.random.randn(b, s, n, h).astype(dtype) k = np.random.randn(b, s_kv, n, h).astype(dtype) From e8f7d47c364c5879e37d2a323bfcd042f31ef182 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 3 Mar 2023 16:18:34 -0800 Subject: [PATCH 11/11] support bias broadcast --- .../contrib/cutlass/attention_operation.py | 27 +++++++-- python/tvm/contrib/cutlass/build.py | 6 ++ python/tvm/contrib/cutlass/gen_tensor_op.py | 13 +++- python/tvm/relax/op/nn/nn.py | 3 +- src/relax/op/nn/attention.cc | 24 +++++--- tests/python/relax/test_codegen_cutlass.py | 60 +++++++++++++++++-- 6 files changed, 115 insertions(+), 18 deletions(-) diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index 14f57b132de6..9093a03dd6ed 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -23,15 +23,32 @@ def instantiate_attention_template(attrs, func_args): """Return CUTLASS host code for fused multi head attention based on a template and the provided attribute map.""" - bias_template = """ + bias_template = { + "B11S'": """ + CHECK(${arg3}->ndim == 2); // B, 1, 1, S' + + p.attn_bias_ptr = reinterpret_cast(${arg3}->data); + p.bias_strideM = 0; // 0 + p.bias_strideH = 0; // 0 + p.bias_strideB = p.num_keys; // S' +""", + "B1SS'": """ + CHECK(${arg3}->ndim == 3); // B, 1, S, S' + + p.attn_bias_ptr = reinterpret_cast(${arg3}->data); + p.bias_strideM = p.num_keys; // S' + p.bias_strideH = 0; // 0 + p.bias_strideB = p.bias_strideM * p.num_queries; // S' * S +""", + "BNSS'": """ CHECK(${arg3}->ndim == 4); // B, N, S, S' p.attn_bias_ptr = reinterpret_cast(${arg3}->data); p.bias_strideM = p.num_keys; // S' p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S p.bias_strideB = p.bias_strideH * p.num_heads; // S' * S * N - -""" +""", + } template = """ using T = ${data_type}; @@ -107,7 +124,9 @@ def instantiate_attention_template(attrs, func_args): kernel_fn<<>>(p); """ if attrs["kSupportsBias"]: - template = substitute_template(template, {"bias_template": bias_template}) + template = substitute_template( + template, {"bias_template": bias_template[attrs["bias_layout"]]} + ) else: template = substitute_template(template, {"bias_template": ""}) for i, arg in enumerate(func_args): diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index c86df6e82978..0e8d419baecf 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -773,6 +773,11 @@ def 
    num_batches, num_queries, num_heads, head_dim = q_shape
    _, num_keys, _, _ = k_shape
    _, _, _, head_dim_value = v_shape
+    bias = {}
+    if "arg3_dtype" in signature:
+        bias["arg3_dtype"] = signature["arg3_dtype"]
+    if "arg3_shape" in signature:
+        bias["arg3_shape"] = signature["arg3_shape"]

    return f.with_attrs(
        {
@@ -792,6 +797,7 @@ def handle_attention(self, f, op_type):
            "head_dim": head_dim,
            "head_dim_value": head_dim_value,
            "arch": self.options["sm"],
+            **bias,
        }
    )
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 2fe4ae5e3ff1..78e2b489c6fa 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -683,7 +683,18 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_
        attrs["output_size"] = b * s * n * h_v
        attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
        attrs["kSupportsDropout"] = False
-        attrs["kSupportsBias"] = len(func_args) > 3
+        if len(func_args) > 3:
+            attrs["kSupportsBias"] = True
+            if len(annotations["arg3_shape"]) == 4:
+                attrs["bias_layout"] = "BNSS'"
+            elif len(annotations["arg3_shape"]) == 3:
+                attrs["bias_layout"] = "B1SS'"
+            elif len(annotations["arg3_shape"]) == 2:
+                attrs["bias_layout"] = "B11S'"
+            else:
+                raise NotImplementedError()
+        else:
+            attrs["kSupportsBias"] = False
        code = instantiate_attention_template(attrs, func_args)
        return CodegenResult(code, headers)
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 3c7f0614bdfe..2fef37249703 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -606,7 +606,8 @@ def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None)
    bias: Optional[Expr]
        The optional attention bias to the operator.
        The layout of the attention bias should be
-        (batch_size, num_head, seq_len, seq_len_kv).
+        (batch_size, num_head, seq_len, seq_len_kv),
+        (batch_size, seq_len, seq_len_kv) or (batch_size, seq_len_kv).

    Returns
    -------
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index c8abf49c336e..e139aa09d692 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -80,16 +80,24 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
  if (input_sinfo.size() == 4) {
    TensorStructInfo bias_sinfo = input_sinfo[3];
-    if (bias_sinfo->ndim != 4) {
+    const ShapeExprNode* bias_shape = bias_sinfo->shape.as<ShapeExprNode>();
+    if (bias_sinfo->ndim == 4) {
+      diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size");
+      diag_equal(num_heads, bias_shape->values[1], "query", "bias", "number of heads");
+      diag_equal(num_queries, bias_shape->values[2], "query", "bias", "sequence length");
+      diag_equal(num_keys, bias_shape->values[3], "key", "bias", "sequence length");
+    } else if (bias_sinfo->ndim == 3) {
+      diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size");
+      diag_equal(num_queries, bias_shape->values[1], "query", "bias", "sequence length");
+      diag_equal(num_keys, bias_shape->values[2], "key", "bias", "sequence length");
+    } else if (bias_sinfo->ndim == 2) {
+      diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size");
+      diag_equal(num_keys, bias_shape->values[1], "key", "bias", "sequence length");
+    } else {
      ctx->ReportFatal(Diagnostic::Error(call)
-          << "The bias should have 4 dimension, namely "
-          << "[batch size, sequence length, number of heads, sequence length].");
+          << "The bias should have 2, 3 or 4 dimensions. "
+          << "However, the bias input has " << bias_sinfo->ndim << " dimensions.");
    }
-    const ShapeExprNode* bias_shape = bias_sinfo->shape.as<ShapeExprNode>();
-    diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size");
-    diag_equal(num_heads, bias_shape->values[1], "query", "bias", "number of heads");
-    diag_equal(num_queries, bias_shape->values[2], "query", "bias", "sequence length");
-    diag_equal(num_keys, bias_shape->values[3], "key", "bias", "sequence length");
  }

  Array<PrimExpr> output_shape = {num_batches, num_queries, num_heads, head_dim_value};
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 5da83a8e3dd6..36a1c4cd16ff 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -369,8 +369,8 @@ def test_attention_offload(attention_size, attention_dtype):
    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


-@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_offload")
-def get_numpy_attention_bias_ref(b, s, s_kv, n, h, h_v, dtype):
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_4d_offload")
+def get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, dtype):
    q = np.random.randn(b, s, n, h).astype(dtype)
    k = np.random.randn(b, s_kv, n, h).astype(dtype)
    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
@@ -385,9 +385,61 @@ def get_numpy_attention_bias_ref(b, s, s_kv, n, h, h_v, dtype):
    return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v


-def test_attention_bias_offload(attention_size, attention_dtype):
+def test_attention_bias_4d_offload(attention_size, attention_dtype):
    b, (s, s_kv), n, (h, h_v) = attention_size
-    q, k, v, bias, ref = get_numpy_attention_bias_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+    q, k, v, bias, ref = get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+
+    mod = get_relax_attention_module(q, k, v, bias)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_3d_offload")
+def get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, dtype):
+    q = np.random.randn(b, s, n, h).astype(dtype)
+    k = np.random.randn(b, s_kv, n, h).astype(dtype)
+    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    bias = np.random.randn(b, s, s_kv).astype(dtype)
+    qt = q.transpose(0, 2, 1, 3) # b, n, s, h
+    kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
+    score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
+    score_bias = score + bias.reshape(b, 1, s, s_kv) # b, n, s, s_kv
+    attn = tvm.topi.testing.softmax_python(score_bias, -1)
+    vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
+    ref = attn @ vt # b, n, s, h_v
+    return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
+
+
+def test_attention_bias_3d_offload(attention_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, bias, ref = get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+
+    mod = get_relax_attention_module(q, k, v, bias)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_2d_offload")
+def get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, dtype):
+    q = np.random.randn(b, s, n, h).astype(dtype)
+    k = np.random.randn(b, s_kv, n, h).astype(dtype)
+    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    bias = np.random.randn(b, s_kv).astype(dtype)
+    qt = q.transpose(0, 2, 1, 3) # b, n, s, h
+    kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
+    score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
+    score_bias = score + bias.reshape(b, 1, 1, s_kv) # b, n, s, s_kv
+    attn = tvm.topi.testing.softmax_python(score_bias, -1)
+    vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
+    ref = attn @ vt # b, n, s, h_v
+    return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
+
+
+def test_attention_bias_2d_offload(attention_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, bias, ref = get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, attention_dtype)

    mod = get_relax_attention_module(q, k, v, bias)
    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)