diff --git a/3rdparty/libflash_attn b/3rdparty/libflash_attn
index 58b343e57571..63cce0ca8fa6 160000
--- a/3rdparty/libflash_attn
+++ b/3rdparty/libflash_attn
@@ -1 +1 @@
-Subproject commit 58b343e57571fe5e0a5b43b5eb721acef8b35dff
+Subproject commit 63cce0ca8fa6bfca1982b342588273641cc5b86b
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 67a68df442f8..e59dbf032e6a 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -169,10 +169,10 @@ def instantiate_flash_attention_template(attrs):
         int k_head_stride = ${head_dim};
         int v_head_stride = ${head_dim};
         int o_head_stride = ${head_dim};
-        int q_row_stride = q_head_stride * ${num_heads};
-        int k_row_stride = k_head_stride * ${num_heads};
-        int v_row_stride = v_head_stride * ${num_heads};
-        int o_row_stride = o_head_stride * ${num_heads};
+        int q_row_stride = q_head_stride * ${num_q_heads};
+        int k_row_stride = k_head_stride * ${num_kv_heads};
+        int v_row_stride = v_head_stride * ${num_kv_heads};
+        int o_row_stride = o_head_stride * ${num_q_heads};
         int q_batch_stride = q_row_stride * ${num_queries};
         int k_batch_stride = k_row_stride * ${num_keys};
         int v_batch_stride = v_row_stride * ${num_keys};
@@ -190,8 +190,8 @@ def instantiate_flash_attention_template(attrs):
               ${num_batches},
               ${num_queries},
               ${num_keys},
-              ${num_heads},
-              ${num_heads},
+              ${num_q_heads},
+              ${num_kv_heads},
               ${head_dim},
               q_batch_stride,
               k_batch_stride,
@@ -215,13 +215,13 @@ def instantiate_flash_attention_template(attrs):
         int k_head_stride = ${head_dim};
         int v_head_stride = ${head_dim};
         int o_head_stride = ${head_dim};
-        int row_stride = q_head_stride * ${num_heads} +
-                         k_head_stride * ${num_heads} +
-                         v_head_stride * ${num_heads};
+        int row_stride = q_head_stride * ${num_q_heads} +
+                         k_head_stride * ${num_kv_heads} +
+                         v_head_stride * ${num_kv_heads};
         int q_row_stride = row_stride;
         int k_row_stride = row_stride;
         int v_row_stride = row_stride;
-        int o_row_stride = o_head_stride * ${num_heads};
+        int o_row_stride = o_head_stride * ${num_q_heads};
 
         int q_batch_stride = q_row_stride * ${num_queries};
         int k_batch_stride = k_row_stride * ${num_keys};
@@ -234,14 +234,14 @@ def instantiate_flash_attention_template(attrs):
 
        flash_attn::flash_attention_forward(
               static_cast(${qkv}->data),
-              static_cast(${qkv}->data) + ${head_dim} * ${num_heads},
-              static_cast(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
+              static_cast(${qkv}->data) + ${head_dim} * ${num_q_heads},
+              static_cast(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads}),
               static_cast(out0->data),
               ${num_batches},
               ${num_queries},
               ${num_keys},
-              ${num_heads},
-              ${num_heads},
+              ${num_q_heads},
+              ${num_kv_heads},
               ${head_dim},
               q_batch_stride,
               k_batch_stride,
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 0c57c4750e87..b97fc20008b4 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -909,8 +909,8 @@ def handle_attention(self, f, op_type):
         out_shape = signature["ret_shape"]
         out_dtype = signature["ret_dtype"]
 
-        num_batches, num_queries, num_heads, head_dim = q_shape
-        _, num_keys, _, _ = k_shape
+        num_batches, num_queries, num_q_heads, head_dim = q_shape
+        _, num_keys, num_kv_heads, _ = k_shape
         _, _, _, head_dim_value = v_shape
 
         scale = op_attrs.scale
@@ -931,7 +931,8 @@ def handle_attention(self, f, op_type):
                 "num_batches": num_batches,
                 "num_queries": num_queries,
                 "num_keys": num_keys,
-                "num_heads": num_heads,
+                "num_q_heads": num_q_heads,
+                "num_kv_heads": num_kv_heads,
                 "head_dim": head_dim,
                 "head_dim_value": head_dim_value,
                 "scale": scale,
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 58bc91863dcc..62e64549c2ae 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -745,7 +745,6 @@ def get_batch_on_arg(arg_name, arg_shape):
 
         attrs["data_type"] = DataTypeTag[data_type]
         attrs["num_batches"] = b = annotations["num_batches"]
-        attrs["num_heads"] = n = annotations["num_heads"]
         attrs["head_dim"] = h = annotations["head_dim"]
         attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
         attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))
@@ -753,26 +752,40 @@ def get_batch_on_arg(arg_name, arg_shape):
             float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
         )
 
+        is_mqa = annotations["num_q_heads"] != annotations["num_kv_heads"]
+
         use_flash = (
             annotations["ret_dtype"] == "float16"
             and "bias" not in attrs
            and int(attrs["head_dim"]) <= 256
            and int(attrs["head_dim"]) % 8 == 0
            and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
-            # We have not thoroughly validated flash with causal mask yet, so for now we support
-            # only non-causal cases.
-            and int(annotations["custom_mask_type"]) == 0
+            # For the causal case (custom mask = "BottomRight"), only use flash for multi-query
+            # attention workloads. Otherwise, CUTLASS fMHA seems faster for causal attention
+            # with a single query.
+            and (
+                int(annotations["custom_mask_type"]) == 0
+                or (int(annotations["custom_mask_type"]) == 2 and is_mqa)
+            )
             # Flash v2 is currently not supported for sm < 80
            and int(annotations["arch"]) >= 80
         )
 
         if use_flash:
             headers.append("flash.h")
-            attrs["is_causal"] = int(annotations["custom_mask_type"]) > 0
+            attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2
+            attrs["num_q_heads"] = annotations["num_q_heads"]
+            attrs["num_kv_heads"] = annotations["num_kv_heads"]
             code = instantiate_flash_attention_template(attrs)
         else:
             headers.append("kernel_forward.h")
 
+            assert (
+                not is_mqa
+            ), "The number of query and KV heads need to be the same for CUTLASS fMHA."
+
+            attrs["num_heads"] = n = annotations["num_q_heads"]
+
             data_type_size = DataTypeSize[data_type]
             if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
                 attrs["kIsAligned"] = True
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index fef6a1ec03c4..9efea3a0dccf 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -576,7 +576,7 @@ def annotate_workspace(mod, _):
     return mod
 
 
-def partition_for_cutlass(mod, annotate_codegen=True):
+def partition_for_cutlass(mod, annotate_codegen=True, use_flash_mqa=True):
     """
     Partition the input module into CUTLASS-supported subgraphs.
 
@@ -590,6 +590,10 @@ def partition_for_cutlass(mod, annotate_codegen=True):
         body consists only of a call to the composite function. See the doc of
         FuseOpsByPattern for more detail.
 
+    use_flash_mqa: bool
+        Whether to consider a rewrite pattern for multi-query attention, which is supported by
+        the Flash Attention kernel.
+
     Returns
     -------
     mod: tvm.IRModule
@@ -598,8 +602,15 @@ def partition_for_cutlass(mod, annotate_codegen=True):
     """
     for func_name, func in mod.functions.items():
         if isinstance(func, Function):
+            if use_flash_mqa:
+                mqa_pattern, rewriter = make_attention_rewrite_pattern(
+                    "BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True
+                )
+                func = rewrite_call(mqa_pattern, rewriter, func)
+
             for pattern, rewriter in _REWRITE_PATTERNS:
                 func = rewrite_call(pattern, rewriter, func)
+
             mod[func_name] = func
 
     patterns = get_patterns_with_prefix("cutlass")
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 24edd0e7c950..10a075647b5a 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -318,7 +318,7 @@ def make_rms_norm_pattern():
 
 
 def make_attention_rewrite_pattern(
-    qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool
+    qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
 ):
     """
     Create pattern for implicit fused multi head attention rewriting.
 
@@ -338,6 +338,10 @@ def make_attention_rewrite_pattern(
         Whether or not rewriting is intended to be applied to a module after the FP16
         conversion pass.
 
+    with_kv_repeat: bool
+        Whether or not to include the Relax repeat op in the pattern, which is typically used
+        in a Relax module to support multi-query attention.
+
     Returns
     -------
     pattern: DFPattern
@@ -350,7 +354,10 @@ def make_attention_rewrite_pattern(
    """
    # pylint: disable=invalid-name
 
-    def handle_input(tensor, layout, transpose):
+    def handle_input(tensor, layout, transpose, repeat=False):
+        if repeat:
+            tensor = is_op("relax.repeat")(tensor)
+
         if layout == "BSNH":
             permuted = is_op("relax.permute_dims")(tensor)
             shape = wildcard()
@@ -434,8 +441,8 @@ def rewriter(matchings, x):
 
     q_raw, k_raw, v_raw = wildcard(), wildcard(), wildcard()
     q, q_rewriter = handle_input(q_raw, qkv_layout, False)
-    k, k_rewriter = handle_input(k_raw, qkv_layout, True)
-    v, v_rewriter = handle_input(v_raw, qkv_layout, False)
+    k, k_rewriter = handle_input(k_raw, qkv_layout, True, repeat=with_kv_repeat)
+    v, v_rewriter = handle_input(v_raw, qkv_layout, False, repeat=with_kv_repeat)
     matmul_1 = is_op("relax.matmul")(q, k)
     scale = is_const()
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index 4f37e3a33c29..484137fecc40 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -77,10 +77,19 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
                        << v1 << " while the " << dim << " of " << m2 << " is " << v2);
     }
   };
 
+  auto multiple_of = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) {
+    if (analyzer->CanProve(indexmod(v1, v2) != 0)) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The " << m1 << " " << dim << " should be a multiple of " << m2 << " "
However, the " << dim << " of " << m1 << " is " << v1 + << " while the " << dim << " of " << m2 << " is " << v2); + } + }; + diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size"); diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size"); - diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads"); - diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads"); + multiple_of(num_heads, k_shape->values[2], "query", "key", "number of heads"); + multiple_of(num_heads, v_shape->values[2], "query", "value", "number of heads"); diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length"); diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads"); diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index e8d4e83521b0..83936ef9c99f 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -746,7 +746,7 @@ def attention_causal(request): def test_attention_causal_offload(attention_causal_size, attention_causal): b, (s, s_kv), n, (h, h_v), bias_shape = attention_causal_size q, k, v, bias, ref = get_numpy_attention_ref( - b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float32" + b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float16" ) q_shape = (b, s, n, h) @@ -757,10 +757,11 @@ def test_attention_causal_offload(attention_causal_size, attention_causal): q_shape, k_shape, v_shape, - dtype="float32", + dtype="float16", bias_shape=bias_shape, causal_mask=attention_causal, ) + if bias is None: out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3) else: @@ -1945,5 +1946,51 @@ def main( tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) +def test_attention_rewrite_multi_query(): + @I.ir_module + class Module: + @R.function + def main( + q: R.Tensor((4, 16, 32, 16), dtype="float16"), + k_single: R.Tensor((4, 16, 1, 16), dtype="float16"), + v_single: R.Tensor((4, 16, 1, 16), dtype="float16"), + ) -> R.Tensor((4, 16, 32, 8), dtype="float16"): + with R.dataflow(): + k = R.repeat(k_single, 32, axis=2) + v = R.repeat(v_single, 32, axis=2) + + lv = R.permute_dims(q, axes=[0, 2, 1, 3]) + lv1 = R.reshape(lv, R.shape([128, 16, 16])) + lv2 = R.permute_dims(k, axes=[0, 2, 1, 3]) + lv3 = R.reshape(lv2, R.shape([128, 16, 16])) + lv4 = R.permute_dims(v, axes=[0, 2, 1, 3]) + lv5 = R.reshape(lv4, R.shape([128, 16, 16])) + + lv6 = R.permute_dims(lv3, axes=[0, 2, 1]) + lv7 = R.matmul(lv1, lv6, out_dtype="float16") + lv3_1 = R.astype(R.const(0.25, "float32"), "float16") + lv8 = R.multiply(lv7, lv3_1) + lv11 = R.astype(R.nn.softmax(R.astype(lv8, "float32"), axis=2), "float16") + lv12 = R.matmul(lv11, lv5, out_dtype="float16") + lv13 = R.reshape(lv12, R.shape([4, 32, 16, 16])) + lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3]) + R.output(lv6_1) + return lv6_1 + + q_np = np.random.randn(4, 16, 32, 16).astype("float16") + k_np = np.random.randn(4, 16, 1, 16).astype("float16") + v_np = np.random.randn(4, 16, 1, 16).astype("float16") + args = [q_np, k_np, v_np] + ref = build_and_run(Module, args, "llvm", legalize=True) + + mod = partition_for_cutlass(Module, use_flash_mqa=True) + codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80}}) + mod = codegen_pass(mod) + + out = build_and_run(mod, args, "cuda") + + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + if __name__ == "__main__": tvm.testing.main()