Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 60 additions & 32 deletions python/tvm/contrib/cutlass/attention_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,45 +19,87 @@
from .library import *


def instantiate_attention_template(attrs, func_args):
def instantiate_attention_template(attrs):
"""Return CUTLASS host code for fused multi head attention
based on a template and the provided attribute map."""

bias_template = {
"B11S'": """
CHECK(${arg3}->ndim == 2); // B, 1, 1, S'
CHECK(${bias}->ndim == 2); // B, 1, 1, S'

p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
p.attn_bias_ptr = reinterpret_cast<T *>(${bias}->data);
p.bias_strideM = 0; // 0
p.bias_strideH = 0; // 0
p.bias_strideB = p.num_keys; // S'
""",
"B1SS'": """
CHECK(${arg3}->ndim == 3); // B, 1, S, S'
CHECK(${bias}->ndim == 3); // B, 1, S, S'

p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
p.attn_bias_ptr = reinterpret_cast<T *>(${bias}->data);
p.bias_strideM = p.num_keys; // S'
p.bias_strideH = 0; // 0
p.bias_strideB = p.bias_strideM * p.num_queries; // S' * S
""",
"BNSS'": """
CHECK(${arg3}->ndim == 4); // B, N, S, S'
CHECK(${bias}->ndim == 4); // B, N, S, S'

p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
p.attn_bias_ptr = reinterpret_cast<T *>(${bias}->data);
p.bias_strideM = p.num_keys; // S'
p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S
p.bias_strideB = p.bias_strideH * p.num_heads; // S' * S * N
""",
}

qkv_template = {
"default": """
p.query_ptr = reinterpret_cast<T *>(${query}->data);
p.key_ptr = reinterpret_cast<T *>(${key}->data);
p.value_ptr = reinterpret_cast<T *>(${value}->data);
CHECK(${query}->ndim == 4); // B, S, N, H
CHECK(${key}->ndim == 4); // B, S', N, H
CHECK(${value}->ndim == 4); // B, S', N, H'

// stride for N
p.q_strideH = p.head_dim; // H
p.k_strideH = p.head_dim; // H
p.v_strideH = p.head_dim_value; // H'

// stride for S
p.q_strideM = p.q_strideH * p.num_heads; // H * N
p.k_strideM = p.k_strideH * p.num_heads; // H * N
p.v_strideM = p.v_strideH * p.num_heads; // H' * N

// stride for B
p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'
""",
"qkv_stacked": """
p.query_ptr = reinterpret_cast<T *>(${qkv}->data);
p.key_ptr = reinterpret_cast<T *>(${qkv}->data) + p.head_dim * p.num_heads;
p.value_ptr = reinterpret_cast<T *>(${qkv}->data) + p.head_dim * p.num_heads * 2;
CHECK(${qkv}->ndim == 3); // B, S, NH + NH + NH'

// stride for N
p.q_strideH = p.head_dim; // H
p.k_strideH = p.head_dim; // H
p.v_strideH = p.head_dim_value; // H'

// stride for S
p.q_strideM = p.k_strideM = p.v_strideM =
p.q_strideH * p.num_heads +
p.k_strideH * p.num_heads +
p.v_strideH * p.num_heads; // H * N + H * N + H * N'

// stride for B
p.q_strideB = p.k_strideB = p.v_strideB =
p.q_strideM * p.num_queries; // (H * N + H * N + H * N') * S
""",
}

template = """
using T = ${data_type};

CHECK(${arg0}->ndim == 4); // B, S, N, H
CHECK(${arg1}->ndim == 4); // B, S', N, H
CHECK(${arg2}->ndim == 4); // B, S', N, H'
CHECK(out0->ndim == 4); // B, S, N, H'

using Attention =
AttentionKernel<T,
/*ArchTag=*/${arch},
Expand All @@ -70,10 +112,6 @@ def instantiate_attention_template(attrs, func_args):
>;

typename Attention::Params p;

p.query_ptr = reinterpret_cast<T *>(${arg0}->data);
p.key_ptr = reinterpret_cast<T *>(${arg1}->data);
p.value_ptr = reinterpret_cast<T *>(${arg2}->data);
p.logsumexp_ptr = nullptr;
p.output_ptr = reinterpret_cast<T *>(out0->data);
p.output_accum_ptr = nullptr;
Expand All @@ -92,22 +130,11 @@ def instantiate_attention_template(attrs, func_args):
p.num_keys = ${num_keys}; // S'
p.scale = ${scale};

// stride for N
p.q_strideH = p.head_dim; // H
p.k_strideH = p.head_dim; // H
p.v_strideH = p.head_dim_value; // H'

// stride for S
p.q_strideM = p.q_strideH * p.num_heads; // H * N
p.k_strideM = p.k_strideH * p.num_heads; // H * N
p.v_strideM = p.v_strideH * p.num_heads; // H' * N
p.o_strideM = p.head_dim_value * p.num_heads; // H' * N
CHECK(out0->ndim == 4); // B, S, N, H'

// stride for B
p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'

${qkv_template}
${bias_template}

constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
Expand All @@ -126,9 +153,10 @@ def instantiate_attention_template(attrs, func_args):

template = substitute_template(
template,
{"bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" in attrs else ""},
{
"qkv_template": qkv_template[attrs["qkv_layout"]],
"bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" in attrs else "",
},
)

for i, arg in enumerate(func_args):
attrs["arg{}".format(i)] = arg
return substitute_template(template, attrs)
56 changes: 35 additions & 21 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,12 +546,17 @@ def _extract_relax_function_signature(f):

for i, arg in enumerate(f.params):
sinfo = arg.struct_info
signature["arg%d_shape" % i] = get_const_tuple(sinfo.shape)
signature["arg%d_dtype" % i] = sinfo.dtype
if isinstance(sinfo, relax.TensorStructInfo):
signature["arg%d_shape" % i] = get_const_tuple(sinfo.shape)
signature["arg%d_dtype" % i] = sinfo.dtype
elif isinstance(sinfo, relax.ShapeStructInfo):
signature["arg%d_shape" % i] = get_const_tuple(sinfo.values)
else:
raise NotImplementedError()

ret_sinfo = f.ret_struct_info
if ret_sinfo.shape is not None:
signature["ret_shape"] = list(ret_sinfo.shape)
signature["ret_shape"] = get_const_tuple(ret_sinfo.shape)
else:
signature["ret_shape"] = None
signature["ret_dtype"] = ret_sinfo.dtype
Expand Down Expand Up @@ -779,34 +784,42 @@ def handle_attention(self, f, op_type):
op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs
else:
raise ValueError(f"Cannot find call node for attention")
q_shape = signature["arg0_shape"]
k_shape = signature["arg1_shape"]
v_shape = signature["arg2_shape"]
arg = {}

if "stacked_attention" in op_type:
arg["arg0_shape"] = signature["arg0_shape"]
arg["arg0_dtype"] = signature["arg0_dtype"]
arg["arg1_shape"] = q_shape = signature["arg1_shape"]
arg["arg2_shape"] = k_shape = signature["arg2_shape"]
arg["arg3_shape"] = v_shape = signature["arg3_shape"]
if "arg4_dtype" in signature:
arg["bias_dtype"] = signature["arg4_dtype"]
if "arg4_shape" in signature:
arg["bias_shape"] = signature["arg4_shape"]
qkv_layout = "qkv_stacked"
else:
arg["arg0_shape"] = q_shape = signature["arg0_shape"]
arg["arg1_shape"] = k_shape = signature["arg1_shape"]
arg["arg2_shape"] = v_shape = signature["arg2_shape"]
arg["arg0_dtype"] = signature["arg0_dtype"]
arg["arg1_dtype"] = signature["arg1_dtype"]
arg["arg2_dtype"] = signature["arg2_dtype"]
if "arg3_dtype" in signature:
arg["bias_dtype"] = signature["arg3_dtype"]
if "arg3_shape" in signature:
arg["bias_shape"] = signature["arg3_shape"]
qkv_layout = "default"
out_shape = signature["ret_shape"]
q_dtype = signature["arg0_dtype"]
k_dtype = signature["arg1_dtype"]
v_dtype = signature["arg2_dtype"]
out_dtype = signature["ret_dtype"]
num_batches, num_queries, num_heads, head_dim = q_shape
_, num_keys, _, _ = k_shape
_, _, _, head_dim_value = v_shape
scale = op_attrs.scale
bias = {}
if "arg3_dtype" in signature:
bias["arg3_dtype"] = signature["arg3_dtype"]
if "arg3_shape" in signature:
bias["arg3_shape"] = signature["arg3_shape"]

return f.with_attrs(
{
"op_type": op_type,
"arg0_dtype": q_dtype,
"arg1_dtype": k_dtype,
"arg2_dtype": v_dtype,
"ret_dtype": out_dtype,
"arg0_shape": q_shape,
"arg1_shape": k_shape,
"arg2_shape": v_shape,
"ret_shape": out_shape,
"num_batches": num_batches,
"num_queries": num_queries,
Expand All @@ -816,7 +829,8 @@ def handle_attention(self, f, op_type):
"head_dim_value": head_dim_value,
"scale": scale,
"arch": self.options["sm"],
**bias,
"qkv_layout": qkv_layout,
**arg,
}
)

Expand Down
35 changes: 21 additions & 14 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,12 +500,6 @@ def instantiate_template(func_name, annotations, func_args):
if k in annotations:
attrs[k] = annotations[k]

arg0_shape = annotations["arg0_shape"]
arg1_shape = annotations["arg1_shape"]
attrs["ElementInputA"] = DataTypeTag[dtype_map[annotations["arg0_dtype"]]]
attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations["arg1_dtype"]]]
attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]]

headers = []

if "relu" in func_name:
Expand Down Expand Up @@ -649,12 +643,12 @@ def get_batch_on_arg(arg_name, arg_shape):
if "conv2d_transpose" in func_name:
headers.append("cutlass/conv/kernel/default_conv2d_dgrad.h")
activation_shape = output_shape
output_shape = arg0_shape
output_shape = annotations["arg0_shape"]
elif "backward" in func_name:
headers.append("cutlass/conv/kernel/default_conv2d_wgrad.h")
activation_shape = arg1_shape
activation_shape = annotations["arg1_shape"]
weight_shape = output_shape
output_shape = arg0_shape
output_shape = annotations["arg0_shape"]
elif "residual" in func_name:
headers.append("cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h")
else:
Expand Down Expand Up @@ -731,21 +725,34 @@ def get_batch_on_arg(arg_name, arg_shape):
), "Cutlass may generate nan occasionally when scale == 0.0"
attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
attrs["kSupportsDropout"] = False
if len(func_args) > 3:
attrs["qkv_layout"] = annotations["qkv_layout"]
if attrs["qkv_layout"] == "default":
attrs["query"] = func_args[0]
attrs["key"] = func_args[1]
attrs["value"] = func_args[2]
if len(func_args) > 3:
attrs["bias"] = func_args[3]
elif attrs["qkv_layout"] == "qkv_stacked":
attrs["qkv"] = func_args[0]
if len(func_args) > 4:
attrs["bias"] = func_args[4]
else:
raise NotImplementedError()
if "bias" in attrs:
attrs["kSupportsBias"] = True
if len(annotations["arg3_shape"]) == 4:
if len(annotations["bias_shape"]) == 4:
attrs["bias_layout"] = "BNSS'"
elif len(annotations["arg3_shape"]) == 3:
elif len(annotations["bias_shape"]) == 3:
attrs["bias_layout"] = "B1SS'"
elif len(annotations["arg3_shape"]) == 2:
elif len(annotations["bias_shape"]) == 2:
attrs["bias_layout"] = "B11S'"
else:
raise NotImplementedError()
else:
# To support negative scale in current Cutlass implementation,
# kSupportsBias should be set true, or there are nan's as result.
attrs["kSupportsBias"] = attrs["scale"] < 0
code = instantiate_attention_template(attrs, func_args)
code = instantiate_attention_template(attrs)
return CodegenResult(code, headers)

raise ValueError("Do not have a template for {}".format(func_name))
9 changes: 9 additions & 0 deletions python/tvm/relax/backend/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
make_fused_bias_activation_pattern,
make_matmul_pattern,
make_residual_block_pattern,
make_stacked_attention_pattern,
)


Expand Down Expand Up @@ -244,6 +245,14 @@ def attention_patterns():
"cutlass.attention_bias",
*make_attention_pattern(with_bias=True),
),
(
"cutlass.stacked_attention",
Copy link
Member

@vinx13 vinx13 Apr 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does the order of patterns here matter? If we have a subgraph containing both reshape and attention, will cutlass.attention that matches only a single attention operation be selected first?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the order matters here. I tried to change the order with stacked attention first, however, the original attention matches first.

*make_stacked_attention_pattern(),
),
(
"cutlass.stacked_attention",
*make_stacked_attention_pattern(with_bias=True),
),
]


Expand Down
49 changes: 48 additions & 1 deletion python/tvm/relax/backend/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from typing import Dict, Mapping, Tuple, Union

from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard
from tvm.relax.dpl.pattern import DFPattern, is_op, is_tuple_get_item, wildcard


def _with_bias_activation_pattern(
Expand Down Expand Up @@ -168,6 +168,11 @@ def make_attention_pattern(with_bias: bool = False):
"""
Create pattern for fused multi head attention.

Parameters
----------
with_bias: bool
Whether or not to include bias addition

Returns
-------
pattern: DFPattern
Expand All @@ -190,3 +195,45 @@ def make_attention_pattern(with_bias: bool = False):
out = is_op("relax.nn.attention")(query, key, value)

return out, annotations


def make_stacked_attention_pattern(with_bias: bool = False):
"""
Create pattern for fused multi head attention with stacked input.

Parameters
----------
with_bias: bool
Whether or not to include bias addition

Returns
-------
pattern: DFPattern
The resulting pattern describing a fused multi head attention.

annotations: Mapping[str, DFPattern]
A mapping from name to sub pattern. It can be used to extract
important expressions from match result, to power the partition
check function and codegen.
"""
stacked_qkv = wildcard()
qkv_tuple = is_op("relax.split")(stacked_qkv)
query_reshape_list = wildcard()
key_reshape_list = wildcard()
value_reshape_list = wildcard()
query = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 0), query_reshape_list)
key = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 1), key_reshape_list)
value = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 2), value_reshape_list)
annotations = {
"stacked_qkv": stacked_qkv,
"query_reshape_list": query_reshape_list,
"key_reshape_list": key_reshape_list,
"value_reshape_list": value_reshape_list,
}
if with_bias:
bias = wildcard()
annotations["bias"] = bias
out = is_op("relax.nn.attention_bias")(query, key, value, bias)
else:
out = is_op("relax.nn.attention")(query, key, value)
return out, annotations
Loading