From 21d183cc3089c1cf8e018cd61317325eaecd6b57 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 18 Apr 2023 22:28:15 +0000 Subject: [PATCH 1/2] [Unity][BYOC] Add check for stacked attention patterns This PR is a follow-up to #14608 and #14649. It adds checks for the fused stacked attention patterns, so that fusion is only enabled when `stacked_qkv` has `ndim=3` and the `split`/`strided_slice` operates on `axis=2`. --- python/tvm/relax/backend/contrib/cutlass.py | 23 +++++++++++++++++++++ python/tvm/relax/backend/patterns.py | 10 +++++---- src/relax/transform/fuse_ops.cc | 2 +- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 06edd9febfd3..457acdbb0d6f 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -232,6 +232,25 @@ def residual_block_patterns(): return patterns +def _check_stacked_attention(context: PatternCheckContext) -> bool: + """Check if the given stacked attention workload can be offloaded to CUTLASS.""" + if _has_leaking_intermediate_variables(context): + return False + if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3: + return False + if "split" in context.annotated_expr: + split_op = context.annotated_expr["split"] + if not split_op.attrs.axis == 2: + return False + else: + for name in ["query", "key", "value"]: + assert f"strided_slice_{name}" in context.annotated_expr + strided_slice_op = context.annotated_expr[f"strided_slice_{name}"] + if not (len(strided_slice_op.attrs.axes) == 1 and strided_slice_op.attrs.axes[0] == 2): + return False + return True + + def attention_patterns(): """ Returns a list of all attention patterns in cutlass BYOC backend. 
@@ -248,18 +267,22 @@ def attention_patterns(): ( "cutlass.stacked_attention", *make_stacked_attention_pattern(start_op="split"), + _check_stacked_attention, ), ( "cutlass.stacked_attention", *make_stacked_attention_pattern(start_op="split", with_bias=True), + _check_stacked_attention, ), ( "cutlass.stacked_attention", *make_stacked_attention_pattern(start_op="strided_slice"), + _check_stacked_attention, ), ( "cutlass.stacked_attention", *make_stacked_attention_pattern(start_op="strided_slice", with_bias=True), + _check_stacked_attention, ), ] diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index 6197fe44ca70..7119c6c4b0e1 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -220,15 +220,16 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): check function and codegen. """ stacked_qkv = wildcard() + ops = {} if start_op == "split": - qkv_tuple = is_op("relax.split")(stacked_qkv) + ops["split"] = qkv_tuple = is_op("relax.split")(stacked_qkv) query_raw = is_tuple_get_item(qkv_tuple, 0) key_raw = is_tuple_get_item(qkv_tuple, 1) value_raw = is_tuple_get_item(qkv_tuple, 2) elif start_op == "strided_slice": - query_raw = is_op("relax.strided_slice")(stacked_qkv) - key_raw = is_op("relax.strided_slice")(stacked_qkv) - value_raw = is_op("relax.strided_slice")(stacked_qkv) + ops["strided_slice_query"] = query_raw = is_op("relax.strided_slice")(stacked_qkv) + ops["strided_slice_key"] = key_raw = is_op("relax.strided_slice")(stacked_qkv) + ops["strided_slice_value"] = value_raw = is_op("relax.strided_slice")(stacked_qkv) else: raise NotImplementedError() query_reshape_list = wildcard() @@ -242,6 +243,7 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): "query_reshape_list": query_reshape_list, "key_reshape_list": key_reshape_list, "value_reshape_list": value_reshape_list, + **ops, } if with_bias: bias = wildcard() diff --git 
a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index adce61f4b8fe..c9c36bfcd81a 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -1045,7 +1045,7 @@ class PatternBasedPartitioner : ExprVisitor { Map matched_bindings; for (const auto& [pat, match] : matched_result) { - if (pat->IsInstance()) { + if (pat->IsInstance() || pat->IsInstance()) { matched_bindings.Set(value_to_bound_var_[match], match); } } From d44351422f5ccb5ef961ac7abafdbb648dd8d1f4 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 18 Apr 2023 22:52:10 +0000 Subject: [PATCH 2/2] check the order of strided_slice --- python/tvm/relax/backend/contrib/cutlass.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 457acdbb0d6f..0c2f38e3007a 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -243,10 +243,18 @@ def _check_stacked_attention(context: PatternCheckContext) -> bool: if not split_op.attrs.axis == 2: return False else: + last_end = 0 for name in ["query", "key", "value"]: assert f"strided_slice_{name}" in context.annotated_expr strided_slice_op = context.annotated_expr[f"strided_slice_{name}"] - if not (len(strided_slice_op.attrs.axes) == 1 and strided_slice_op.attrs.axes[0] == 2): + if list(strided_slice_op.attrs.axes) != [2]: + return False + if list(strided_slice_op.attrs.begin) != [last_end]: + return False + if not len(strided_slice_op.attrs.end) == 1: + return False + last_end = strided_slice_op.attrs.end[0] + if list(strided_slice_op.attrs.strides) != [1]: return False return True