def _check_stacked_attention(context: PatternCheckContext) -> bool:
    """Check if the given stacked attention workload can be offloaded to CUTLASS.

    Two matched-pattern layouts are accepted:

    * ``split``-based: one ``relax.split`` producing Q/K/V — the split must be
      along axis 2 (the packed hidden dimension).
    * ``strided_slice``-based: three ``relax.strided_slice`` ops carving Q, K
      and V out of the stacked tensor — the slices must all be along axis 2,
      contiguous (each slice begins where the previous one ended) and have
      stride 1, so that together they tile the packed dimension exactly.

    Parameters
    ----------
    context : PatternCheckContext
        Match context; ``context.annotated_expr`` maps the names annotated in
        ``make_stacked_attention_pattern`` to the matched relax expressions.

    Returns
    -------
    bool
        True when the matched subgraph is safe to hand to the CUTLASS kernel.
    """
    if _has_leaking_intermediate_variables(context):
        return False
    # The packed QKV tensor must be rank 3.
    # NOTE(review): presumably (batch, seq, 3 * hidden) given the axis-2
    # split/slices below — confirm against the CUTLASS kernel contract.
    if context.annotated_expr["stacked_qkv"].struct_info.ndim != 3:
        return False
    if "split" in context.annotated_expr:
        # split-based layout: Q/K/V come from a single split on axis 2.
        split_op = context.annotated_expr["split"]
        if split_op.attrs.axis != 2:
            return False
    else:
        # strided_slice-based layout: the three slices must be consecutive,
        # contiguous, unit-stride spans along axis 2.
        last_end = 0
        for name in ("query", "key", "value"):
            # Reject (rather than assert) if the pattern map is missing a
            # slice: a False return is the correct "no offload" outcome, and
            # asserts are stripped under `python -O`.
            if f"strided_slice_{name}" not in context.annotated_expr:
                return False
            strided_slice_op = context.annotated_expr[f"strided_slice_{name}"]
            if list(strided_slice_op.attrs.axes) != [2]:
                return False
            # Each slice must begin exactly where the previous one ended.
            if list(strided_slice_op.attrs.begin) != [last_end]:
                return False
            if len(strided_slice_op.attrs.end) != 1:
                return False
            last_end = strided_slice_op.attrs.end[0]
            if list(strided_slice_op.attrs.strides) != [1]:
                return False
    return True
@@ -248,18 +275,22 @@ def attention_patterns(): ( "cutlass.stacked_attention", *make_stacked_attention_pattern(start_op="split"), + _check_stacked_attention, ), ( "cutlass.stacked_attention", *make_stacked_attention_pattern(start_op="split", with_bias=True), + _check_stacked_attention, ), ( "cutlass.stacked_attention", *make_stacked_attention_pattern(start_op="strided_slice"), + _check_stacked_attention, ), ( "cutlass.stacked_attention", *make_stacked_attention_pattern(start_op="strided_slice", with_bias=True), + _check_stacked_attention, ), ] diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index 6197fe44ca70..7119c6c4b0e1 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -220,15 +220,16 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): check function and codegen. """ stacked_qkv = wildcard() + ops = {} if start_op == "split": - qkv_tuple = is_op("relax.split")(stacked_qkv) + ops["split"] = qkv_tuple = is_op("relax.split")(stacked_qkv) query_raw = is_tuple_get_item(qkv_tuple, 0) key_raw = is_tuple_get_item(qkv_tuple, 1) value_raw = is_tuple_get_item(qkv_tuple, 2) elif start_op == "strided_slice": - query_raw = is_op("relax.strided_slice")(stacked_qkv) - key_raw = is_op("relax.strided_slice")(stacked_qkv) - value_raw = is_op("relax.strided_slice")(stacked_qkv) + ops["strided_slice_query"] = query_raw = is_op("relax.strided_slice")(stacked_qkv) + ops["strided_slice_key"] = key_raw = is_op("relax.strided_slice")(stacked_qkv) + ops["strided_slice_value"] = value_raw = is_op("relax.strided_slice")(stacked_qkv) else: raise NotImplementedError() query_reshape_list = wildcard() @@ -242,6 +243,7 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): "query_reshape_list": query_reshape_list, "key_reshape_list": key_reshape_list, "value_reshape_list": value_reshape_list, + **ops, } if with_bias: bias = wildcard() diff --git 
a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index adce61f4b8fe..c9c36bfcd81a 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -1045,7 +1045,7 @@ class PatternBasedPartitioner : ExprVisitor { Map matched_bindings; for (const auto& [pat, match] : matched_result) { - if (pat->IsInstance()) { + if (pat->IsInstance() || pat->IsInstance()) { matched_bindings.Set(value_to_bound_var_[match], match); } }