2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 25 files
+1 −1 docs/annotated.html
+1 −1 docs/default__mma__core__simt_8h_source.html
+1 −1 docs/hierarchy.html
+1 −1 docs/namespacecutlass_1_1transform.html
+1 −1 docs/pitch__linear__thread__map_8h.html
+1 −1 docs/pitch__linear__thread__map_8h_source.html
+1 −1 docs/structcutlass_1_1transform_1_1TransposePitchLinearThreadMap2DThreadTile.html
+46 −7 examples/41_fused_multi_head_attention/debug_utils.h
+6 −6 examples/41_fused_multi_head_attention/default_fmha_grouped.h
+0 −0 examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h
+0 −0 examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h
+0 −0 examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h
+132 −11 examples/41_fused_multi_head_attention/fmha_grouped.h
+7 −2 examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu
+2 −0 examples/41_fused_multi_head_attention/gemm/find_default_mma.h
+17 −152 examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h
+263 −23 examples/41_fused_multi_head_attention/gemm/mma_from_smem.h
+437 −52 examples/41_fused_multi_head_attention/kernel_forward.h
+88 −0 examples/41_fused_multi_head_attention/transform/tile_smem_loader.h
+1 −1 include/cutlass/gemm/device/base_grouped.h
+52 −126 include/cutlass/gemm/kernel/gemm_universal_streamk.h
+2 −9 include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h
+1 −1 include/cutlass/transform/pitch_linear_thread_map.h
+1 −1 test/unit/common/cutlass_unit_test.h
+36 −16 tools/library/scripts/generator.py
134 changes: 134 additions & 0 deletions python/tvm/contrib/cutlass/attention_operation.py
@@ -0,0 +1,134 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
"""Generator for CUTLASS attention kernels."""
from .library import *


def instantiate_attention_template(attrs, func_args):
"""Return CUTLASS host code for fused multi head attention
based on a template and the provided attribute map."""

bias_template = {
"B11S'": """
CHECK(${arg3}->ndim == 2); // B, 1, 1, S'

p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
p.bias_strideM = 0; // 0
p.bias_strideH = 0; // 0
p.bias_strideB = p.num_keys; // S'
""",
"B1SS'": """
CHECK(${arg3}->ndim == 3); // B, 1, S, S'

p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
p.bias_strideM = p.num_keys; // S'
p.bias_strideH = 0; // 0
p.bias_strideB = p.bias_strideM * p.num_queries; // S' * S
""",
"BNSS'": """
CHECK(${arg3}->ndim == 4); // B, N, S, S'

p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
p.bias_strideM = p.num_keys; // S'
p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S
p.bias_strideB = p.bias_strideH * p.num_heads; // S' * S * N
""",
}

template = """
using T = ${data_type};

CHECK(${arg0}->ndim == 4); // B, S, N, H
CHECK(${arg1}->ndim == 4); // B, S', N, H
CHECK(${arg2}->ndim == 4); // B, S', N, H'
CHECK(out0->ndim == 4); // B, S, N, H'

using Attention =
AttentionKernel<T,
/*ArchTag=*/${arch},
/*is_aligned=*/${kIsAligned},
/*queries_per_block=*/${kQueriesPerBlock},
/*keys_per_block=*/${kKeysPerBlock},
/*single_value_iteration=*/${kSingleValueIteration},
/*supports_dropout=*/${kSupportsDropout},
/*supports_bias=*/${kSupportsBias}
>;

typename Attention::Params p;

p.query_ptr = reinterpret_cast<T *>(${arg0}->data);
p.key_ptr = reinterpret_cast<T *>(${arg1}->data);
p.value_ptr = reinterpret_cast<T *>(${arg2}->data);
p.logsumexp_ptr = nullptr;
p.output_ptr = reinterpret_cast<T *>(out0->data);
p.output_accum_ptr = nullptr;
if (Attention::kNeedsOutputAccumulatorBuffer) {
cudaMalloc(
&p.output_accum_ptr,
${output_size} * sizeof(Attention::output_accum_t)
);
}

p.num_heads = ${num_heads}; // N
p.num_batches = ${num_batches}; // B
p.head_dim = ${head_dim}; // H
p.head_dim_value = ${head_dim_value}; // H'
p.num_queries = ${num_queries}; // S
p.num_keys = ${num_keys}; // S'
p.scale = 1.0f / sqrt(float(${head_dim}));

// stride for N
p.q_strideH = p.head_dim; // H
p.k_strideH = p.head_dim; // H
p.v_strideH = p.head_dim_value; // H'

// stride for S
p.q_strideM = p.q_strideH * p.num_heads; // H * N
p.k_strideM = p.k_strideH * p.num_heads; // H * N
p.v_strideM = p.v_strideH * p.num_heads; // H' * N
p.o_strideM = p.head_dim_value * p.num_heads; // H' * N

// stride for B
p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'

${bias_template}

constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
int smem_bytes = sizeof(typename Attention::SharedStorage);
if (smem_bytes > 0xc000) {
static bool once = [&]() {
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
return true;
}();
}

CHECK(Attention::check_supported(p));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
"""
if attrs["kSupportsBias"]:
template = substitute_template(
template, {"bias_template": bias_template[attrs["bias_layout"]]}
)
else:
template = substitute_template(template, {"bias_template": ""})
for i, arg in enumerate(func_args):
attrs["arg{}".format(i)] = arg
return substitute_template(template, attrs)
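
For reference, a minimal sketch of driving this template directly. The attribute values below are hypothetical; in the real flow they are assembled by the codegen in gen_tensor_op.py further down.

from tvm.contrib.cutlass.attention_operation import instantiate_attention_template

# Hypothetical attributes for a bias-free fp16 attention with B=2, S=S'=128,
# N=8, H=H'=64; func_args are the names of the incoming tensor arguments.
example_attrs = {
    "data_type": "cutlass::half_t",
    "arch": "cutlass::arch::Sm80",
    "kIsAligned": True,
    "kQueriesPerBlock": 64,
    "kKeysPerBlock": 64,
    "kSingleValueIteration": True,
    "kSupportsDropout": False,
    "kSupportsBias": False,
    "num_batches": 2,
    "num_queries": 128,
    "num_keys": 128,
    "num_heads": 8,
    "head_dim": 64,
    "head_dim_value": 64,
    "output_size": 2 * 128 * 8 * 64,
}
print(instantiate_attention_template(example_attrs, ["inp0", "inp1", "inp2"]))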
47 changes: 47 additions & 0 deletions python/tvm/contrib/cutlass/build.py
@@ -57,6 +57,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
cutlass_root = _get_cutlass_path()
cutlass_include = os.path.join(cutlass_root, "include")
cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")
cutlass_attention_include = os.path.join(cutlass_root, "examples/41_fused_multi_head_attention")

kwargs = {}
kwargs["cc"] = "nvcc"
@@ -71,6 +72,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
"-std=c++17",
"-I" + cutlass_include,
"-I" + cutlass_util_include,
"-I" + cutlass_attention_include,
]
if use_fast_math:
kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
@@ -756,6 +758,49 @@ def handle_matmul(self, f, op_type):
}
)

def handle_attention(self, f, op_type):
"""Tune and annotate a dense op."""
signature = _extract_relax_function_signature(f)

q_shape = signature["arg0_shape"]
k_shape = signature["arg1_shape"]
v_shape = signature["arg2_shape"]
out_shape = signature["ret_shape"]
q_dtype = signature["arg0_dtype"]
k_dtype = signature["arg1_dtype"]
v_dtype = signature["arg2_dtype"]
out_dtype = signature["ret_dtype"]
num_batches, num_queries, num_heads, head_dim = q_shape
_, num_keys, _, _ = k_shape
_, _, _, head_dim_value = v_shape
bias = {}
if "arg3_dtype" in signature:
bias["arg3_dtype"] = signature["arg3_dtype"]
if "arg3_shape" in signature:
bias["arg3_shape"] = signature["arg3_shape"]

return f.with_attrs(
{
"op_type": op_type,
"arg0_dtype": q_dtype,
"arg1_dtype": k_dtype,
"arg2_dtype": v_dtype,
"ret_dtype": out_dtype,
"arg0_shape": q_shape,
"arg1_shape": k_shape,
"arg2_shape": v_shape,
"ret_shape": out_shape,
"num_batches": num_batches,
"num_queries": num_queries,
"num_keys": num_keys,
"num_heads": num_heads,
"head_dim": head_dim,
"head_dim_value": head_dim_value,
"arch": self.options["sm"],
**bias,
}
)

def visit_function_(self, f):
if "Composite" not in f.attrs:
body = super().visit_expr(f.body)
@@ -767,6 +812,8 @@ def visit_function_(self, f):
return self.handle_conv2d(f, op_type)
elif "matmul" in op_type:
return self.handle_matmul(f, op_type)
elif "attention" in op_type:
return self.handle_attention(f, op_type)

raise ValueError("Unsupported composite {}".format(op_type))

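
As an illustration of the shape bookkeeping in handle_attention, a small standalone sketch with hypothetical shapes (the real pass reads them from the Relax function signature):

# Hypothetical q/k/v shapes in the (batch, seq, heads, head_dim) layout the
# annotator expects.
q_shape = (2, 128, 8, 64)   # B, S,  N, H
k_shape = (2, 256, 8, 64)   # B, S', N, H
v_shape = (2, 256, 8, 32)   # B, S', N, H'

num_batches, num_queries, num_heads, head_dim = q_shape
_, num_keys, _, _ = k_shape
_, _, _, head_dim_value = v_shape

# The resulting annotations: B=2, S=128, S'=256, N=8, H=64, H'=32.
assert (num_batches, num_queries, num_keys, num_heads,
        head_dim, head_dim_value) == (2, 128, 256, 8, 64, 32)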
69 changes: 57 additions & 12 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -30,8 +30,10 @@
from . import _ffi_api as ffi
from .conv2d_operation import instantiate_conv2d_template
from .gemm_operation import instantiate_gemm_template
from .attention_operation import instantiate_attention_template
from .library import (
DataType,
DataTypeSize,
DataTypeTag,
EpilogueFunctor,
MathInstruction,
@@ -549,7 +551,7 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_
attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]]
attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]]

attrs["K"] = str(int(lhs_shape[lhs_batched_offset + 1]))
attrs["K"] = lhs_shape[lhs_batched_offset + 1]
attrs["M"] = get_dim(lhs_shape[lhs_batched_offset], lhs_arg, 0, lhs_batched_offset)

if transposed:
@@ -630,27 +632,70 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_
attrs["N"] = get_dim(activation_shape[0], activation_var, 0)
attrs["H"] = get_dim(activation_shape[1], activation_var, 1)
attrs["W"] = get_dim(activation_shape[2], activation_var, 2)
attrs["C"] = str(int(activation_shape[3]))
attrs["C"] = activation_shape[3]
attrs["P"] = get_dim(output_shape[1], "out0", 1)
attrs["Q"] = get_dim(output_shape[2], "out0", 2)
attrs["K"] = str(int(output_shape[3]))
attrs["R"] = str(int(weight_shape[1]))
attrs["S"] = str(int(weight_shape[2]))
attrs["pad_h"] = str(int(annotations["padding"][0]))
attrs["pad_w"] = str(int(annotations["padding"][1]))
attrs["stride_h"] = str(int(annotations["strides"][0]))
attrs["stride_w"] = str(int(annotations["strides"][1]))
attrs["dilation_h"] = str(int(annotations["dilation"][0]))
attrs["dilation_w"] = str(int(annotations["dilation"][1]))
attrs["K"] = output_shape[3]
attrs["R"] = weight_shape[1]
attrs["S"] = weight_shape[2]
attrs["pad_h"] = annotations["padding"][0]
attrs["pad_w"] = annotations["padding"][1]
attrs["stride_h"] = annotations["strides"][0]
attrs["stride_w"] = annotations["strides"][1]
attrs["dilation_h"] = annotations["dilation"][0]
attrs["dilation_w"] = annotations["dilation"][1]

if "splitk" in op_name:
attrs["split_k_mode"] = "kParallel"
attrs["split_k_slices"] = str(re.search(r"splitk(\d+)", op_name).group(1))
else:
attrs["split_k_mode"] = "kSerial"
attrs["split_k_slices"] = "1"
attrs["split_k_slices"] = 1

code = instantiate_conv2d_template(attrs, func_args)
return CodegenResult(code, headers)

elif "attention" in func_name:
headers.append("kernel_forward.h")
data_type = dtype_map[annotations["arg0_dtype"]]
attrs["data_type"] = DataTypeTag[data_type]
attrs["num_batches"] = b = annotations["num_batches"]
attrs["num_queries"] = s = annotations["num_queries"]
attrs["num_keys"] = annotations["num_keys"]
attrs["num_heads"] = n = annotations["num_heads"]
attrs["head_dim"] = h = annotations["head_dim"]
attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
data_type_size = DataTypeSize[data_type]
if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
attrs["kIsAligned"] = True
elif (h % 4 == 0) and (h_v % 4 == 0):
attrs["kIsAligned"] = False
else:
raise NotImplementedError()
if h_v > 64:
attrs["kQueriesPerBlock"] = 32
attrs["kKeysPerBlock"] = 128
attrs["kSingleValueIteration"] = h_v <= 128
else:
attrs["kQueriesPerBlock"] = 64
attrs["kKeysPerBlock"] = 64
attrs["kSingleValueIteration"] = True
attrs["output_size"] = b * s * n * h_v
attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
attrs["kSupportsDropout"] = False
if len(func_args) > 3:
attrs["kSupportsBias"] = True
if len(annotations["arg3_shape"]) == 4:
attrs["bias_layout"] = "BNSS'"
elif len(annotations["arg3_shape"]) == 3:
attrs["bias_layout"] = "B1SS'"
elif len(annotations["arg3_shape"]) == 2:
attrs["bias_layout"] = "B11S'"
else:
raise NotImplementedError()
else:
attrs["kSupportsBias"] = False
code = instantiate_attention_template(attrs, func_args)
return CodegenResult(code, headers)

raise ValueError("Do not have a template for {}".format(func_name))
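
A worked example of the alignment and tile-size heuristic used in the attention branch above, written as a standalone sketch (the helper name is ours, not part of the codegen):

def pick_attention_config(data_type_size_bits, h, h_v):
    # The aligned kernel variant needs 16-byte rows for both head dims;
    # otherwise fall back to the unaligned variant if both are multiples of 4.
    if (data_type_size_bits * h // 8) % 16 == 0 and (data_type_size_bits * h_v // 8) % 16 == 0:
        aligned = True
    elif h % 4 == 0 and h_v % 4 == 0:
        aligned = False
    else:
        raise NotImplementedError()
    # Large value head dims get a 32x128 tile and may need several value
    # iterations; otherwise a 64x64 tile with a single value iteration.
    if h_v > 64:
        return aligned, 32, 128, h_v <= 128
    return aligned, 64, 64, True

# fp16, H = H' = 64: 16 * 64 / 8 = 128 bytes per row, so the aligned
# 64x64 single-iteration kernel is selected.
assert pick_attention_config(16, 64, 64) == (True, 64, 64, True)
# fp16, H' = 256: still aligned, but uses the 32x128 tile and multiple
# value iterations since 256 > 128.
assert pick_attention_config(16, 64, 256) == (True, 32, 128, False)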
6 changes: 6 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -20,6 +20,8 @@
import enum
from enum import auto as enum_auto

from tvm.tir.expr import IntImm


class GeneratorTarget(enum.Enum):
Library = enum_auto()
@@ -143,6 +145,10 @@ def substitute_template(template, values):
while changed:
changed = False
for key, value in values.items():
if isinstance(value, bool):
# check bool before int: bool is a subclass of int in Python
value = str(value).lower()
elif isinstance(value, (int, IntImm)):
value = str(int(value))
regex = "\\$\\{%s\\}" % key
newtext = re.sub(regex, value, text)
if newtext != text:
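
A rough sketch of the substitution behaviour for the newly supported integer and boolean values:

from tvm.contrib.cutlass.library import substitute_template

snippet = "/*is_aligned=*/${kIsAligned}, /*queries_per_block=*/${kQueriesPerBlock}"
# Integers (including IntImm) are stringified; booleans become C++-style
# "true"/"false" so they can be spliced straight into the kernel template.
print(substitute_template(snippet, {"kIsAligned": True, "kQueriesPerBlock": 64}))
# -> /*is_aligned=*/true, /*queries_per_block=*/64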
14 changes: 13 additions & 1 deletion python/tvm/relax/backend/contrib/cutlass.py
@@ -25,7 +25,11 @@
from tvm.relax.dpl import DFPattern

from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern
from ..patterns import (
make_fused_bias_activation_pattern,
make_matmul_pattern,
make_attention_pattern,
)


def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]:
@@ -157,6 +161,14 @@ def _check_matmul(
),
_check_matmul,
),
(
"cutlass.attention",
*make_attention_pattern(),
),
(
"cutlass.attention_bias",
*make_attention_pattern(with_bias=True),
),
]
)

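
An end-to-end usage sketch under assumed option names (partition_for_cutlass and relax.transform.RunCodegen are the existing Relax BYOC entry points; the exact codegen options may differ):

from tvm import relax
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

def offload_attention_to_cutlass(mod, sm=80):
    # Match the cutlass.attention / cutlass.attention_bias patterns registered
    # above and outline them into composite functions.
    mod = partition_for_cutlass(mod)
    # Run the CUTLASS codegen on the partitioned functions; "sm" selects the
    # target architecture (the option name is an assumption here).
    mod = relax.transform.RunCodegen({"cutlass": {"sm": sm}})(mod)
    return relax.build(mod, target="cuda")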
27 changes: 27 additions & 0 deletions python/tvm/relax/backend/patterns.py
@@ -113,3 +113,30 @@ def make_matmul_pattern(
out = is_op("relax.matmul")(lhs, rhs)

return _with_bias_activation_pattern(out, args, with_bias, activation)


def make_attention_pattern(with_bias: bool = False):
"""
Create pattern for fused multi-head attention.

Returns
-------
pattern: DFPattern
The resulting pattern describing a fused multi-head attention op.

args: Mapping[str, DFPattern]
The mapping from argument name to its pattern. It can be used to extract
the argument expressions from a match result.
"""
query = wildcard()
key = wildcard()
value = wildcard()
args = {"query": query, "key": key, "value": value}
if with_bias:
bias = wildcard()
args["bias"] = bias
out = is_op("relax.nn.attention_bias")(query, key, value, bias)
else:
out = is_op("relax.nn.attention")(query, key, value)

return out, args
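
A small sketch of what the helper returns (the bias variant simply adds one more wildcard):

from tvm.relax.backend.patterns import make_attention_pattern

# Without bias: matches relax.nn.attention(query, key, value).
pattern, arg_map = make_attention_pattern()
assert set(arg_map) == {"query", "key", "value"}

# With bias: matches relax.nn.attention_bias(query, key, value, bias); the
# extra "bias" wildcard lets the backend recover the bias expression later.
pattern_bias, arg_map_bias = make_attention_pattern(with_bias=True)
assert set(arg_map_bias) == {"query", "key", "value", "bias"}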