diff --git a/python/tvm/relax/backend/dispatch_sampling.py b/python/tvm/relax/backend/dispatch_sampling.py index 68d162fdf19b..528529c723c9 100644 --- a/python/tvm/relax/backend/dispatch_sampling.py +++ b/python/tvm/relax/backend/dispatch_sampling.py @@ -36,7 +36,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: return super().visit_call_(call) if call.op.name == "relax.multinomial_from_uniform": - from tvm.relax.backend_tir import ( # pylint: disable=import-outside-toplevel + from tvm.relax.backend.gpu_generic import ( # pylint: disable=import-outside-toplevel generic_get_sample_index, gpu_multinomial_from_uniform, ) diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py index b5a94619c228..9f7cbaee9a99 100644 --- a/python/tvm/relax/backend/dispatch_sort_scan.py +++ b/python/tvm/relax/backend/dispatch_sort_scan.py @@ -141,7 +141,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: and call.op.name == "relax.cumsum" and call.attrs.exclusive == 0 ): - from tvm.relax.backend_tir import ( # pylint: disable=import-outside-toplevel + from tvm.relax.backend.gpu_generic import ( # pylint: disable=import-outside-toplevel gpu_2d_continuous_cumsum, ) diff --git a/python/tvm/relax/backend/gpu_generic/__init__.py b/python/tvm/relax/backend/gpu_generic/__init__.py index ea2d2a2afb5a..d7c316d28cdc 100644 --- a/python/tvm/relax/backend/gpu_generic/__init__.py +++ b/python/tvm/relax/backend/gpu_generic/__init__.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. """The Relax Metal backend compilation pipeline and other passes.""" +from .cumsum import gpu_2d_continuous_cumsum from .pipeline import ( + dataflow_lower_passes, finalize_passes, get_default_pipeline, legalize_passes, - dataflow_lower_passes, library_dispatch_passes, ) +from .sampling import generic_get_sample_index, gpu_multinomial_from_uniform diff --git a/python/tvm/relax/backend_tir/cumsum.py b/python/tvm/relax/backend/gpu_generic/cumsum.py similarity index 100% rename from python/tvm/relax/backend_tir/cumsum.py rename to python/tvm/relax/backend/gpu_generic/cumsum.py diff --git a/python/tvm/relax/backend_tir/sampling.py b/python/tvm/relax/backend/gpu_generic/sampling.py similarity index 100% rename from python/tvm/relax/backend_tir/sampling.py rename to python/tvm/relax/backend/gpu_generic/sampling.py diff --git a/python/tvm/relax/backend_tir/__init__.py b/python/tvm/relax/backend_tir/__init__.py deleted file mode 100644 index b64bdcda6bb6..000000000000 --- a/python/tvm/relax/backend_tir/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Relax backends, tir based""" - -from . import contrib -from .cumsum import gpu_2d_continuous_cumsum -from .pattern import get_tir_pattern -from .sampling import gpu_multinomial_from_uniform, generic_get_sample_index diff --git a/python/tvm/relax/backend_tir/contrib/__init__.py b/python/tvm/relax/backend_tir/contrib/__init__.py deleted file mode 100644 index 9274f22374b9..000000000000 --- a/python/tvm/relax/backend_tir/contrib/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""External backend codegen modules for Relax, tir based.""" - -from .cutlass import cutlass_fcodegen diff --git a/python/tvm/relax/backend_tir/contrib/cutlass.py b/python/tvm/relax/backend_tir/contrib/cutlass.py deleted file mode 100644 index 0dbe31c468ad..000000000000 --- a/python/tvm/relax/backend_tir/contrib/cutlass.py +++ /dev/null @@ -1,720 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name,comparison-with-callable,unused-variable,missing-function-docstring -"""codegen for cutlass""" -import operator -from functools import reduce -from typing import List, Dict, Any - -from tvm.contrib.cutlass.build import _get_cutlass_path, _get_cutlass_compile_options -from tvm.contrib.nvcc import get_target_compute_version -from tvm.contrib.cutlass.library import LayoutType, ConvKind -from tvm.contrib.cutlass.gen_tensor_op import instantiate_template -from tvm.contrib.cutlass.gen_gemm import CutlassGemmProfiler -from tvm.contrib.cutlass.gen_conv2d import CutlassConv2DProfiler -from ..pattern import ( - MatchResult, - matmul_rrr_fp16, - bias_row_2d_fp16, - bias_row_1d_fp16, - batch_bias_row_2d_fp16, - batch_bias_row_1d_fp16, - relu_fp16, - erf_3d_fp32, - batch_matmul_rrr_2d_fp16, - batch_matmul_rrr_3d_fp16, - conv2d_nhwc_fp16, - padding_2d_nhwc_fp16, - copy_4d_fp16, - bias_add_nhwc_2d_fp16, - bias_add_nhwc_1d_fp16, - elem_add_4d_fp16, - elem_mul_3d_fp16, - scalar_add_3d_fp16, - scalar_mul_3d_fp16, - cast_3d_fp16, - cast_3d_fp32, -) - -#### helper functions #### -# list representing the anchor ops -# in the future more layouts/dtypes can be supported -MATMUL_LIST = [matmul_rrr_fp16] -MATMUL_BIAS_LIST = [bias_row_2d_fp16, bias_row_1d_fp16] -BATCH_MATMUL_LIST = [batch_matmul_rrr_2d_fp16, batch_matmul_rrr_3d_fp16] -BATCH_MATMUL_BIAS_LIST = [batch_bias_row_2d_fp16, batch_bias_row_1d_fp16] -CONV2D_LIST = [conv2d_nhwc_fp16] - -# attributes for anchor ops used in code generation -OP_PATTERN_ATTR_LIST = { - matmul_rrr_fp16: { - "arg0_dtype": "float16", - "arg1_dtype": "float16", - "ret_dtype": "float16", - }, - batch_matmul_rrr_2d_fp16: { - "arg0_dtype": "float16", - "arg1_dtype": "float16", - "ret_dtype": "float16", - }, - batch_matmul_rrr_3d_fp16: { - "arg0_dtype": "float16", - "arg1_dtype": "float16", - "ret_dtype": "float16", - }, - conv2d_nhwc_fp16: { - "arg0_dtype": "float16", - "arg1_dtype": "float16", - "ret_dtype": "float16", - # in the future we can add layout here - }, -} - - -def _get_cutlass_code(attr): - pattern = attr["op_type"] - if pattern.startswith("cutlass.matmul"): - return cutlass_codegen_gemm(attr) - elif pattern.startswith("cutlass.conv2d"): - return cutlass_codegen_conv2d(attr) - else: - raise ValueError("op not supported") - - -def _final_code(code, headers, func_args): - res = "" - res += "#define DMLC_USE_LOGGING_LIBRARY \n" - res += "#include \n" - res += "#include \n" - res += "#include \n" - res += "#include \n" - res += "#include \n" - res += "#include \n" - res += "#include \n" - res += "#include \n" - - for header in headers: - res += "#include <" + header + ">\n" - res += "namespace {\n" - res += "using namespace tvm;\n" - res += "using namespace tvm::runtime;\n" - res += "void _cutlass_kernel(" - for arg in func_args: - res += "NDArray " + arg + ", " - res += "NDArray out0) {" - res += code - res += "}\n" - res += "} // namespace\n" - res += "TVM_DLL_EXPORT_TYPED_FUNC({global_symbol}, _cutlass_kernel);\n" - return res - - -#### cutlass patterns #### -def matmul_bias_relu(match_results, attr, get_code=True): - if len(match_results) < 3: - return None - attr = matmul_bias(match_results[:2], attr, get_code=False) - if attr is None or match_results[2].pattern != relu_fp16: - return None - m_bias, n_bias = match_results[1].symbol_values - m_relu, n_relu = match_results[2].symbol_values - A_bias, B_bias, C_bias = match_results[1].matched_buffers - A_relu, B_relu = match_results[2].matched_buffers - if m_bias == m_relu and n_bias == n_relu and C_bias == A_relu: - attr["op_type"] = "cutlass.matmul_bias_relu" - return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr - return None - - -def matmul_bias(match_results, attr, get_code=True): - if len(match_results) < 2: - return None - attr = matmul(match_results[:1], attr, get_code=False) - if attr is None or match_results[1].pattern not in MATMUL_BIAS_LIST: - return None - m_matmul, n_matmul, k_matmul = match_results[0].symbol_values - m_bias, n_bias = match_results[1].symbol_values - A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers - A_bias, B_bias, C_bias = match_results[1].matched_buffers - if m_matmul == m_bias and n_matmul == n_bias and C_matmul == A_bias: - attr["op_type"] = "cutlass.matmul_bias" - attr["bias_arg_idx"] = 2 - attr["args"].append(B_bias) - return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr - return None - - -def matmul(match_results, attr, get_code=True): - if len(match_results) < 1: - return None - if match_results[0].pattern in MATMUL_LIST: - # matmul - attr["op_type"] = "cutlass.matmul" - return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code else attr - return None - - -def batch_matmul_bias_gelu(match_results, attr, get_code=True): - if len(match_results) < 9: - return None - attr = batch_matmul_bias(match_results[:2], attr, get_code=False) # batch_matmul, batch_bias - if ( - attr is None - or match_results[2].pattern != scalar_mul_3d_fp16 - or match_results[3].pattern != cast_3d_fp32 - or match_results[4].pattern != erf_3d_fp32 - or match_results[5].pattern != cast_3d_fp16 - or match_results[6].pattern != scalar_mul_3d_fp16 - or match_results[7].pattern != scalar_add_3d_fp16 - or match_results[8].pattern != elem_mul_3d_fp16 - ): - return None - - def shape_match_3d(shape1, shape2): - if len(shape1) < 3 or len(shape2) < 3: - return False - return shape1[0] == shape2[0] and shape1[1] == shape2[1] and shape1[2] == shape2[2] - - for i in range(1, 8): - if not shape_match_3d(match_results[i].symbol_values, match_results[i + 1].symbol_values): - return None - - if not ( - match_results[1].matched_buffers[-1] == match_results[2].matched_buffers[0] - and match_results[2].matched_buffers[-1] == match_results[3].matched_buffers[0] - and match_results[3].matched_buffers[-1] == match_results[4].matched_buffers[0] - and match_results[4].matched_buffers[-1] == match_results[5].matched_buffers[0] - and match_results[5].matched_buffers[-1] == match_results[6].matched_buffers[0] - and match_results[6].matched_buffers[-1] == match_results[7].matched_buffers[0] - and match_results[1].matched_buffers[-1] == match_results[8].matched_buffers[0] - and match_results[7].matched_buffers[-1] == match_results[8].matched_buffers[1] - ): - return None - - if ( - abs(float(match_results[2].symbol_values[-1] - 0.5**0.5)) > 1e-5 - or abs(float(match_results[6].symbol_values[-1] - 0.5)) > 1e-5 - or abs(float(match_results[7].symbol_values[-1] - 0.5)) > 1e-5 - ): - return None - - attr["op_type"] = "cutlass.matmul_bias_gelu" - return [_get_cutlass_code(attr=attr), 9, attr["args"]] if get_code else attr - - -def batch_matmul_bias_residual_mul(match_results, attr, get_code=True): - if len(match_results) < 3: - return None - attr = batch_matmul_bias(match_results[:2], attr, get_code=False) # batch_matmul, batch_bias - if attr is None or match_results[2].pattern != elem_mul_3d_fp16: - return None - ( - b_bias, - m_bias, - n_bias, - ) = match_results[1].symbol_values - ( - b_mul, - m_mul, - n_mul, - ) = match_results[2].symbol_values - A_bias, B_bias, C_bias = match_results[1].matched_buffers - A_mul, B_mul, C_mul = match_results[2].matched_buffers - if b_bias == b_mul and m_bias == m_mul and n_bias == n_mul and C_bias == A_mul: - attr["op_type"] = "cutlass.matmul_bias_residual_multiply" - attr["residual_arg_idx"] = 3 - return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr - return None - - -def batch_matmul_bias(match_results, attr, get_code=True): - if len(match_results) < 2: - return None - attr = batch_matmul(match_results[:1], attr, get_code=False) - if attr is None or match_results[1].pattern not in BATCH_MATMUL_BIAS_LIST: - return None - ( - b_matmul, - m_matmul, - n_matmul, - k_matmul, - ) = match_results[0].symbol_values - ( - b_bias, - m_bias, - n_bias, - ) = match_results[1].symbol_values - A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers - A_bias, B_bias, C_bias = match_results[1].matched_buffers - if b_matmul == b_bias and m_matmul == m_bias and n_matmul == n_bias and C_matmul == A_bias: - attr["op_type"] = "cutlass.matmul_bias" - attr["bias_arg_idx"] = 2 - attr["args"].append(B_bias) - return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr - return None - - -def batch_matmul(match_results, attr, get_code=True): - if len(match_results) < 1: - return None - if match_results[0].pattern in BATCH_MATMUL_LIST: - attr["op_type"] = "cutlass.matmul" - return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code else attr - return None - - -def conv2d_bias_residual_add(match_results, attr, get_code=True): - if len(match_results) < 4: - return None - attr = conv2d_bias(match_results[:3], attr, get_code=False) - if attr is None or match_results[3].pattern != elem_add_4d_fp16: - return None - N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values - in1_bias, in2_bias, out_bias = match_results[2].matched_buffers - N_add, H_add, W_add, C_add = match_results[3].symbol_values - in1_add, in2_add, out_add = match_results[3].matched_buffers - if ( - N_bias == N_add - and H_bias == H_add - and W_bias == W_add - and C_bias == C_add - and out_bias in [in1_add, in2_add] - ): - attr["op_type"] = "cutlass.conv2d_bias_residual_add" - attr["residual_arg_idx"] = 3 - attr["args"].append(in2_add if out_bias == in1_add else in1_add) - return [_get_cutlass_code(attr=attr), 4, attr["args"]] if get_code else attr - return None - - -def conv2d_bias(match_results, attr, get_code=True): - if len(match_results) < 3: - return None - attr = conv2d(match_results[:2], attr, get_code=False) - if attr is None or ( - match_results[2].pattern not in [bias_add_nhwc_2d_fp16, bias_add_nhwc_1d_fp16] - ): - return None - (N_conv, pH_conv, pW_conv, H_conv, W_conv, C_conv, O_conv,) = match_results[ - 1 - ].symbol_values[:7] - A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers - N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values - A_bias, B_bias, out_bias = match_results[2].matched_buffers - if ( - N_bias == N_conv - and H_bias == H_conv - and W_bias == W_conv - and C_bias == O_conv - and out_conv == A_bias - ): - attr["op_type"] = "cutlass.conv2d_bias" - attr["bias_arg_idx"] = 2 - attr["args"].append(B_bias) - return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr - return None - - -def conv2d(match_results, attr, get_code=True): - if len(match_results) < 2: - return None - if ( - match_results[0].pattern in [padding_2d_nhwc_fp16, copy_4d_fp16] - and match_results[1].pattern == conv2d_nhwc_fp16 - ): - if match_results[0].pattern == padding_2d_nhwc_fp16: - ( - N_pad, - H_pad, - W_pad, - C_pad, - pH_pad, - pW_pad, - lH_pad, - lW_pad, - rH_pad, - rW_pad, - ) = match_results[0].symbol_values - else: - ( - N_pad, - H_pad, - W_pad, - C_pad, - ) = match_results[0].symbol_values - pH_pad = rH_pad = H_pad - pW_pad = rW_pad = W_pad - lH_pad = lW_pad = 0 - ( - N_conv, - pH_conv, - pW_conv, - H_conv, - W_conv, - C_conv, - O_conv, - KH_conv, - KW_conv, - stride_h_conv, - stride_w_conv, - dilation_h_conv, - dilation_w_conv, - ) = match_results[1].symbol_values - A, A_pad = match_results[0].matched_buffers - A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers - if ( - N_pad == N_conv - and pH_pad == pH_conv - and pW_pad == pW_conv - and C_pad == C_conv - and A_pad == A_pad_conv - ): - if ( - lH_pad == pH_pad - rH_pad - and lW_pad == pW_pad - rW_pad - and lH_pad + H_pad == rH_pad - and lW_pad + W_pad == rW_pad - ): - padding = (lH_pad, lW_pad) - strides = (stride_h_conv, stride_w_conv) - dilation = (dilation_h_conv, dilation_w_conv) - attr["padding"] = padding - attr["strides"] = strides - attr["dilation"] = dilation - attr["op_type"] = "cutlass.conv2d" - return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr - return None - - -### cutlass codegen functions ### -def compile_options(target, threads=-1, use_fast_math=False): - compute_version = int("".join(get_target_compute_version(target).split("."))) - kwargs = _get_cutlass_compile_options(compute_version, threads, use_fast_math) - kwargs["options"].remove("-c") - return kwargs - - -def cutlass_fcodegen(sm=80, bin_dir="./bin"): - gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), bin_dir) - conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), bin_dir) - - def cutlass_codegen_with_match_results(match_results: List[MatchResult]): - """generate cutlass code with match results""" - nonlocal gemm_profiler - nonlocal conv2d_profiler - - assert len(match_results) > 0 - - # add shape into attr - if match_results[0].pattern in MATMUL_LIST: - A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers - attr: Dict[Any, Any] = OP_PATTERN_ATTR_LIST[match_results[0].pattern] - attr["args"] = [A_matmul, B_matmul] - attr["arg0_shape"] = A_matmul.shape - attr["arg1_shape"] = B_matmul.shape - attr["ret_shape"] = C_matmul.shape - attr["lhs_arg_idx"] = 0 - attr["rhs_arg_idx"] = 1 - elif match_results[0].pattern in BATCH_MATMUL_LIST: - A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers - attr = OP_PATTERN_ATTR_LIST[match_results[0].pattern] - attr["args"] = [A_matmul, B_matmul] - attr["arg0_shape"] = A_matmul.shape - attr["arg1_shape"] = B_matmul.shape - attr["ret_shape"] = C_matmul.shape - attr["lhs_arg_idx"] = 0 - attr["rhs_arg_idx"] = 1 - elif len(match_results) >= 1 and match_results[1].pattern in CONV2D_LIST: - A_input = match_results[0].matched_buffers[0] - A_conv2d, B_conv2d, C_conv2d = match_results[1].matched_buffers - attr = OP_PATTERN_ATTR_LIST[match_results[1].pattern] - attr["args"] = [A_input, B_conv2d] - attr["arg0_shape"] = A_input.shape - attr["arg1_shape"] = B_conv2d.shape - attr["ret_shape"] = C_conv2d.shape - attr["lhs_arg_idx"] = 0 - attr["rhs_arg_idx"] = 1 - else: - return ["", 0] - - # add profiler into attr - attr["gemm_profiler"] = gemm_profiler - attr["conv2d_profiler"] = conv2d_profiler - - cutlass_patterns = [ - # 9 - batch_matmul_bias_gelu, - # 4 - conv2d_bias_residual_add, - # 3 - batch_matmul_bias_residual_mul, - matmul_bias_relu, - conv2d_bias, - # 2 - matmul_bias, - batch_matmul_bias, - conv2d, - # 1 - matmul, - batch_matmul, - ] - for pattern in cutlass_patterns: - res = pattern(match_results, attr) - if res is not None: - return res - - return ["", 0] - - return cutlass_codegen_with_match_results - - -def cutlass_codegen_gemm(attrs): - """cutlass codegen for gemm""" - gemm_profiler = attrs["gemm_profiler"] - op_type = attrs["op_type"] - lhs_shape = attrs["arg0_shape"] - rhs_shape = attrs["arg1_shape"] - MM = lhs_shape[-2] - KK = lhs_shape[-1] - if "transposed" in op_type: - NN = rhs_shape[-2] - ldb = "K" - layout_b = LayoutType.ColumnMajor - else: - NN = rhs_shape[-1] - ldb = "N" - layout_b = LayoutType.RowMajor - - lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1) - rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1) - if lhs_batches == 1 and rhs_batches == 1: - # Regular matmul - is_batched = False - batch_attrs = {} - else: - is_batched = True - batch_attrs = { - # If both lhs_batches and rhs_batches are greater than 1, - # they must be equal. This is checked by is_shape_valid_for_cutlass_matmul. - "batch": lhs_batches if rhs_batches == 1 else rhs_batches, - "batch_stride_A": 0 if lhs_batches == 1 else MM * KK, - "batch_stride_B": 0 if rhs_batches == 1 else KK * NN, - "batch_stride_C": MM * NN, - } - op_name, op_def, _ = gemm_profiler.profile( - op_type, - MM, - NN, - KK, - attrs["ret_dtype"], - attrs["arg0_dtype"], - attrs["arg1_dtype"], - False, - batched=is_batched, - find_first_valid=False, - use_multiprocessing=True, - layout_b=layout_b, - ) - attrs["cutlass_op_name"] = op_name - attrs["cutlass_op_def"] = op_def - attrs["lda"] = "K" - attrs["ldb"] = ldb - attrs["ldc"] = "N" - attrs.update(batch_attrs) - del attrs["gemm_profiler"] - del attrs["conv2d_profiler"] - - nargs = 2 - if "bias_arg_idx" in attrs: - nargs += 1 - if "residual_arg_idx" in attrs: - nargs += 1 - func_args = ["inp" + str(i) for i in range(nargs)] - - # A temporary solution to handle batch matmul residual cases - # TODO(@bohan): remove this after initialize_template supports bmm residual - if op_type in [ - "cutlass.matmul_bias_residual_multiply", - ]: - - def _convert_dtype_str(dtype): - if isinstance(dtype, list): - arr = [] - for t in dtype: - arr.append(_convert_dtype_str(t)) - return arr - elif isinstance(dtype, str): - if dtype == "float16": - return "cutlass::half_t" - elif dtype == "float32": - return "float" - raise ValueError("dtype not supported") - - typea, typeb, typec = _convert_dtype_str( - [attrs["arg0_dtype"], attrs["arg1_dtype"], attrs["ret_dtype"]] - ) - - text = f""" -#define CUTLASS_ENABLE_CUBLAS 1 -#define CUTLASS_NAMESPACE cutlass -#define CUTLASS_ENABLE_TENSOR_CORE_MMA 1 -#define NDEBUG -#include -#include -#include -#include -#include -#include -#include -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/epilogue/thread/linear_combination_residual_block.h" -#include "cutlass/gemm/device/gemm_universal_with_broadcast.h" -#include -#include -#include -#include -#define DMLC_USE_LOGGING_LIBRARY -#include -#include -#include -namespace {{ -using namespace tvm; -using namespace tvm::runtime; -void _BHGEMM(NDArray A, NDArray B, NDArray Bias, NDArray D, NDArray C) {{ - // A: [Batch, M, K], B: [1, K, N]/[K, N], Bias: [1, N]/[N], D: [Batch, M, N], C: [Batch, M, N] - CHECK_EQ(A->ndim, 3); - int bdim = B->ndim; - int bias_dim = Bias->ndim; - CHECK_EQ(C->ndim, 3); - CHECK_EQ(A->shape[2], B->shape[bdim - 2]); - CHECK_EQ(Bias->shape[bias_dim - 1], B->shape[bdim - 1]); - CHECK_EQ(D->ndim, 3); - CHECK_EQ(D->shape[0], A->shape[0]); - CHECK_EQ(D->shape[1], A->shape[1]); - CHECK_EQ(D->shape[2], B->shape[bdim - 1]); - CHECK_EQ(C->shape[0], A->shape[0]); - CHECK_EQ(C->shape[1], A->shape[1]); - CHECK_EQ(C->shape[2], B->shape[bdim - 1]); - int64_t M = A->shape[0] * A->shape[1]; - int64_t N = B->shape[bdim - 1]; - int64_t K = A->shape[2]; - int64_t input_a_batch_stride = M * K; - int64_t input_a_stride = K; - int64_t input_a_offset = 0; // default to 0 - int64_t input_b_batch_stride = K * N; - int64_t input_b_stride = N; - int64_t input_b_offset = 0; // default to 0 - int64_t output_stride = N; - int64_t output_offset = 0; - int64_t a_size = 1; - a_size *= A->shape[0]; - a_size *= A->shape[1]; - a_size *= A->shape[2]; - - int64_t b_size = 1; - b_size *= B->shape[bias_dim - 2]; - b_size *= B->shape[bias_dim - 1]; - - int64_t c_size = 1; - c_size *= C->shape[0]; - c_size *= C->shape[1]; - c_size *= C->shape[2]; - - // Define the GEMM operation - {op_def} - using kernel = Operation_{op_name}; - using ElementComputeEpilogue = typename kernel::ElementAccumulator; - typename kernel::Arguments arguments({{ - cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode - {{M, N, K}}, // GemmCoord problem_size - 1, // int batch_count - {{ElementComputeEpilogue(1), ElementComputeEpilogue(1)}}, // typename EpilogueOutputOp::Params epilogue - ({typea}*)(A->data) + input_a_offset, // void const * ptr_A - ({typeb}*)(B->data) + input_b_offset, // void const * ptr_B - ({typec}*)(D->data), // void const * ptr_C1 - ({typec}*)(C->data) + output_offset, // void * ptr_D - ({typea}*)(Bias->data), // void * ptr_Vector - nullptr, // void * ptr_Tensor - input_a_batch_stride, // int64_t batch_stride_A - input_b_batch_stride, // int64_t batch_stride_B - 0, // int64_t batch_stride_C1 - 0, // int64_t batch_stride_D - 0, // int64_t batch_stride_Vector - 0, // int64_t batch_stride_Tensor - input_a_stride, // typename LayoutA::Stride::Index lda - input_b_stride, // typename LayoutB::Stride::Index ldb - N, // typename LayoutC::Stride::Index ldc1 - output_stride, // typename LayoutC::Stride::Index ldd - 0, // typename LayoutC::Stride::Index ldr - 0, // typename LayoutC::Stride::Index ldt - }}); - kernel gemm_op; - size_t workspace_size = gemm_op.get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - cutlass::Status status = gemm_op.can_implement(arguments); - CHECK(status == cutlass::Status::kSuccess); - status = gemm_op.initialize(arguments, workspace.get()); - CHECK(status == cutlass::Status::kSuccess); - status = gemm_op(); - CHECK(status == cutlass::Status::kSuccess); - return; -}} -}} // namespace -TVM_DLL_EXPORT_TYPED_FUNC({{global_symbol}}, _BHGEMM); - """ - return text - - code = instantiate_template(op_type, attrs, func_args) - return _final_code(code.code, code.headers, func_args) - - -def cutlass_codegen_conv2d(attrs): - """cutlass codegen for conv2d""" - # cutlass backend only supports nhwc for now - conv2d_profiler = attrs["conv2d_profiler"] - op_type = attrs["op_type"] - conv_kind = ConvKind.Fprop - op_name, op_def, _ = conv2d_profiler.profile( - op_type=attrs["op_type"], - d_shape=attrs["arg0_shape"], - w_shape=attrs["arg1_shape"], - padding=attrs["padding"], - stride=attrs["strides"], - dilation=attrs["dilation"], - out_dtype=attrs["ret_dtype"], - data_dtype=attrs["arg0_dtype"], - weight_dtype=attrs["arg1_dtype"], - use_3xtf32=False, - conv_kind=conv_kind, - split_k_slices=[1], - profile_all_alignments=True, - find_first_valid=False, - use_multiprocessing=True, - ) - attrs["cutlass_op_def"] = op_def - attrs["cutlass_op_name"] = op_name - del attrs["gemm_profiler"] - del attrs["conv2d_profiler"] - - nargs = 2 - if "bias_arg_idx" in attrs: - nargs += 1 - if "residual_arg_idx" in attrs: - nargs += 1 - func_args = ["inp" + str(i) for i in range(nargs)] - code = instantiate_template(op_type, attrs, func_args) - return _final_code(code.code, code.headers, func_args) diff --git a/python/tvm/relax/backend_tir/pattern.py b/python/tvm/relax/backend_tir/pattern.py deleted file mode 100644 index 10f7a3b1628d..000000000000 --- a/python/tvm/relax/backend_tir/pattern.py +++ /dev/null @@ -1,576 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name,missing-function-docstring,chained-comparison -"""TIR Patterns""" -from typing import List - -import tvm -from tvm.runtime import Object -import tvm._ffi - -from tvm.script import tir as T - - -@tvm._ffi.register_object("relax.MatchResult") -class MatchResult(Object): - """The match result of a TIR pattern.""" - - def __init__(self, pattern, symbol_values, matched_buffers): - self.__init_handle_by_constructor__( - tvm._ffi.MatchResult, pattern, symbol_values, matched_buffers - ) - - -@T.prim_func -def matmul_rrr_fp16( - var_rxplaceholder: T.handle, - var_rxplaceholder_1: T.handle, - var_matmul: T.handle, - M: T.int64, - N: T.int64, - K: T.int64, -) -> None: - # function attr dict - T.func_attr({"tir.noalias": True}) - rxplaceholder = T.match_buffer(var_rxplaceholder, [M, K], dtype="float16") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N], dtype="float16") - matmul = T.match_buffer(var_matmul, [M, N], dtype="float16") - # body - # with T.block("root") - for i0, i1, i2 in T.grid(M, N, K): - with T.block("matmul"): - i0_1, i1_1, k = T.axis.remap("SSR", [i0, i1, i2]) - T.reads(rxplaceholder[i0_1, k], rxplaceholder_1[k, i1_1]) - T.writes(matmul[i0_1, i1_1]) - with T.init(): - matmul[i0_1, i1_1] = T.float16(0) - matmul[i0_1, i1_1] = ( - matmul[i0_1, i1_1] + rxplaceholder[i0_1, k] * rxplaceholder_1[k, i1_1] - ) - - -@T.prim_func -def bias_row_2d_fp16( - var_rxplaceholder: T.handle, - var_rxplaceholder_1: T.handle, - var_T_add: T.handle, - M: T.int64, - N: T.int64, -) -> None: - # function attr dict - T.func_attr({"tir.noalias": True}) - rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N], dtype="float16") - T_add = T.match_buffer(var_T_add, [M, N], dtype="float16") - # body - # with T.block("root") - for i0, i1 in T.grid(M, N): - with T.block("T_add"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[T.int64(0), ax1]) - T.writes(T_add[ax0, ax1]) - T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[T.int64(0), ax1] - - -@T.prim_func -def bias_row_1d_fp16( - var_rxplaceholder: T.handle, - var_rxplaceholder_1: T.handle, - var_T_add: T.handle, - M: T.int64, - N: T.int64, -) -> None: - # function attr dict - T.func_attr({"tir.noalias": True}) - rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16") - T_add = T.match_buffer(var_T_add, [M, N], dtype="float16") - # body - # with T.block("root") - for i0, i1 in T.grid(M, N): - with T.block("T_add"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1]) - T.writes(T_add[ax0, ax1]) - T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[ax1] - - -@T.prim_func -def batch_bias_row_2d_fp16( - var_rxplaceholder: T.handle, - var_rxplaceholder_1: T.handle, - var_T_add: T.handle, - batch: T.int64, - M: T.int64, - N: T.int64, -) -> None: - # function attr dict - T.func_attr({"tir.noalias": True}) - rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N], dtype="float16") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N], dtype="float16") - T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16") - # body - # with T.block("root") - for i0, i1, i2 in T.grid(batch, M, N): - with T.block("T_add"): - ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[T.int64(0), ax2]) - T.writes(T_add[ax0, ax1, ax2]) - T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + rxplaceholder_1[T.int64(0), ax2] - - -@T.prim_func -def batch_bias_row_1d_fp16( - var_rxplaceholder: T.handle, - var_rxplaceholder_1: T.handle, - var_T_add: T.handle, - batch: T.int64, - M: T.int64, - N: T.int64, -) -> None: - # function attr dict - T.func_attr({"tir.noalias": True}) - rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N], dtype="float16") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16") - T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16") - # body - # with T.block("root") - for i0, i1, i2 in T.grid(batch, M, N): - with T.block("T_add"): - ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[ax2]) - T.writes(T_add[ax0, ax1, ax2]) - T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + rxplaceholder_1[ax2] - - -@T.prim_func -def relu_fp16(var_rxplaceholder: T.handle, var_compute: T.handle, M: T.int64, N: T.int64) -> None: - # function attr dict - T.func_attr({"tir.noalias": True}) - rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16") - compute = T.match_buffer(var_compute, [M, N], dtype="float16") - # body - # with T.block("root") - for i0, i1 in T.grid(M, N): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float16(0)) - - -@T.prim_func -def batch_matmul_rrr_2d_fp16( - var_rxplaceholder: T.handle, - var_rxplaceholder_1: T.handle, - var_matmul: T.handle, - batch: T.int64, - M: T.int64, - N: T.int64, - K: T.int64, -) -> None: - # function attr dict - T.func_attr({"tir.noalias": True}) - rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K], dtype="float16") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N], dtype="float16") - matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16") - # body - # with T.block("root") - for i0, i1, i2, i3 in T.grid(batch, M, N, K): - with T.block("matmul"): - i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) - T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[k, i2_1]) - T.writes(matmul[i0_1, i1_1, i2_1]) - with T.init(): - matmul[i0_1, i1_1, i2_1] = T.float16(0) - matmul[i0_1, i1_1, i2_1] = ( - matmul[i0_1, i1_1, i2_1] + rxplaceholder[i0_1, i1_1, k] * rxplaceholder_1[k, i2_1] - ) - - -@T.prim_func -def batch_matmul_rrr_3d_fp16( - var_rxplaceholder: T.handle, - var_rxplaceholder_1: T.handle, - var_matmul: T.handle, - batch: T.int64, - M: T.int64, - N: T.int64, - K: T.int64, -) -> None: - # function attr dict - T.func_attr({"tir.noalias": True}) - rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K], dtype="float16") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [batch, K, N], dtype="float16") - matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16") - # body - # with T.block("root") - for i0, i1, i2, i3 in T.grid(batch, M, N, K): - with T.block("matmul"): - i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) - T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[i0_1, k, i2_1]) - T.writes(matmul[i0_1, i1_1, i2_1]) - with T.init(): - matmul[i0_1, i1_1, i2_1] = T.float16(0) - matmul[i0_1, i1_1, i2_1] = ( - matmul[i0_1, i1_1, i2_1] - + rxplaceholder[i0_1, i1_1, k] * rxplaceholder_1[i0_1, k, i2_1] - ) - - -@T.prim_func -def copy_4d_fp16( - A_handle: T.handle, - B_handle: T.handle, - N: T.int64, - H: T.int64, - W: T.int64, - C: T.int64, -) -> None: - A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") - B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16") - # body - # with T.block("root") - for n, h, w, c in T.grid(N, H, W, C): - with T.block("copy"): - vn, vh, vw, vc = T.axis.remap("SSSS", [n, h, w, c]) - T.reads(A[vn, vh, vw, vc]) - T.writes(B[vn, vh, vw, vc]) - B[vn, vh, vw, vc] = A[vn, vh, vw, vc] - - -@T.prim_func -def padding_2d_nhwc_fp16( - A_handle: T.handle, - B_handle: T.handle, - N: T.int64, - H: T.int64, - W: T.int64, - C: T.int64, - pH: T.int64, - pW: T.int64, - lH: T.int64, - lW: T.int64, - rH: T.int64, - rW: T.int64, -) -> None: - A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") - B = T.match_buffer(B_handle, [N, pH, pW, C], dtype="float16") - # body - # with T.block("root") - for v, v_1, v_2, v_3 in T.grid(N, pH, pW, C): - with T.block("copy"): - v_4, v_5, v_6, v_7 = T.axis.remap("SSSS", [v, v_1, v_2, v_3]) - T.reads(A[v_4, v_5 - lH, v_6 - lW, v_7]) - T.writes(B[v_4, v_5, v_6, v_7]) - B[v_4, v_5, v_6, v_7] = T.if_then_else( - lH <= v_5 and v_5 < rH and lW <= v_6 and v_6 < rW, - A[v_4, v_5 - lH, v_6 - lW, v_7], - T.float16(0), - dtype="float16", - ) - - -@T.prim_func -def conv2d_nhwc_fp16( - A_handle: T.handle, - B_handle: T.handle, - out_handle: T.handle, - N: T.int64, - pH: T.int64, - pW: T.int64, - H: T.int64, - W: T.int64, - C: T.int64, - O: T.int64, - KH: T.int64, - KW: T.int64, - StrideH: T.int64, - StrideW: T.int64, - DilateH: T.int64, - DilateW: T.int64, -) -> None: - A = T.match_buffer(A_handle, [N, pH, pW, C], dtype="float16") - B = T.match_buffer(B_handle, [O, KH, KW, C], dtype="float16") - out = T.match_buffer(out_handle, [N, H, W, O], dtype="float16") - # body - # with T.block("root") - for v, v_1, v_2, v_3, v_4, v_5, v_6 in T.grid(N, H, W, O, KH, KW, C): - with T.block("conv"): - v_7, v_8, v_9, v_10, v_11, v_12, v_13 = T.axis.remap( - "SSSSRRR", [v, v_1, v_2, v_3, v_4, v_5, v_6] - ) - T.reads( - A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 * StrideW, v_13], - B[v_10, v_11, v_12, v_13], - ) - T.writes(out[v_7, v_8, v_9, v_10]) - with T.init(): - out[v_7, v_8, v_9, v_10] = T.float16(0) - out[v_7, v_8, v_9, v_10] = ( - out[v_7, v_8, v_9, v_10] - + A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 * StrideW, v_13] - * B[v_10, v_11, v_12, v_13] - ) - - -@T.prim_func -def bias_add_nhwc_2d_fp16( - A_handle: T.handle, - B_handle: T.handle, - out_handle: T.handle, - N: T.int64, - H: T.int64, - W: T.int64, - C: T.int64, -): - A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") - B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16") - out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16") - for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), T.int64(0), v_ax3]) - T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3]) - out[v_ax0, v_ax1, v_ax2, v_ax3] = ( - A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, T.int64(0), T.int64(0), v_ax3] - ) - - -@T.prim_func -def bias_add_nhwc_1d_fp16( - A_handle: T.handle, - B_handle: T.handle, - out_handle: T.handle, - N: T.int64, - H: T.int64, - W: T.int64, - C: T.int64, -): - A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") - B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16") - out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16") - for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[T.int64(0), T.int64(0), T.int64(0), v_ax3]) - T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3]) - out[v_ax0, v_ax1, v_ax2, v_ax3] = ( - A[v_ax0, v_ax1, v_ax2, v_ax3] + B[T.int64(0), T.int64(0), T.int64(0), v_ax3] - ) - - -@T.prim_func -def elem_add_2d_fp16( - in0_handle: T.handle, - in1_handle: T.handle, - out_handle: T.handle, - N: T.int64, - M: T.int64, -): - in0 = T.match_buffer(in0_handle, [N, M], dtype="float16") - in1 = T.match_buffer(in1_handle, [N, M], dtype="float16") - out = T.match_buffer(out_handle, [N, M], dtype="float16") - for ax0, ax1 in T.grid(N, M): - with T.block("T_add"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(in0[v_ax0, v_ax1], in1[v_ax0, v_ax1]) - T.writes(out[v_ax0, v_ax1]) - out[v_ax0, v_ax1] = in0[v_ax0, v_ax1] + in1[v_ax0, v_ax1] - - -@T.prim_func -def elem_add_3d_fp16( - in0_handle: T.handle, - in1_handle: T.handle, - out_handle: T.handle, - B: T.int64, - N: T.int64, - M: T.int64, -): - in0 = T.match_buffer(in0_handle, [B, N, M], dtype="float16") - in1 = T.match_buffer(in1_handle, [B, N, M], dtype="float16") - out = T.match_buffer(out_handle, [B, N, M], dtype="float16") - for ax0, ax1, ax2 in T.grid(B, N, M): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(in0[v_ax0, v_ax1, v_ax2], in1[v_ax0, v_ax1, v_ax2]) - T.writes(out[v_ax0, v_ax1, v_ax2]) - out[v_ax0, v_ax1, v_ax2] = in0[v_ax0, v_ax1, v_ax2] + in1[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def elem_add_4d_fp16( - A_handle: T.handle, - B_handle: T.handle, - out_handle: T.handle, - N: T.int64, - H: T.int64, - W: T.int64, - C: T.int64, -): - A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") - B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16") - out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16") - for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3]) - out[v_ax0, v_ax1, v_ax2, v_ax3] = ( - A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, v_ax1, v_ax2, v_ax3] - ) - - -@T.prim_func -def scalar_mul_3d_fp16( - inp0_handle: T.handle, - out_handle: T.handle, - D1: T.int64, - D2: T.int64, - D3: T.int64, - scalar: T.float16, -): - inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") - out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") - for ax0, ax1, ax2 in T.grid(D1, D2, D3): - with T.block("T_mul"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(inp0[v_ax0, v_ax1, v_ax2]) - T.writes(out[v_ax0, v_ax1, v_ax2]) - out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * scalar - - -@T.prim_func -def erf_3d_fp32( - inp0_handle: T.handle, - out_handle: T.handle, - D1: T.int64, - D2: T.int64, - D3: T.int64, -): - inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32") - out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32") - for ax0, ax1, ax2 in T.grid(D1, D2, D3): - with T.block("T_erf"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(inp0[v_ax0, v_ax1, v_ax2]) - T.writes(out[v_ax0, v_ax1, v_ax2]) - out[v_ax0, v_ax1, v_ax2] = T.erf(inp0[v_ax0, v_ax1, v_ax2]) - - -@T.prim_func -def scalar_add_3d_fp16( - inp0_handle: T.handle, - out_handle: T.handle, - D1: T.int64, - D2: T.int64, - D3: T.int64, - scalar: T.float16, -): - inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") - out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") - for ax0, ax1, ax2 in T.grid(D1, D2, D3): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(inp0[v_ax0, v_ax1, v_ax2]) - T.writes(out[v_ax0, v_ax1, v_ax2]) - out[v_ax0, v_ax1, v_ax2] = scalar + inp0[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def elem_mul_3d_fp16( - inp0_handle: T.handle, - inp1_handle: T.handle, - out_handle: T.handle, - D1: T.int64, - D2: T.int64, - D3: T.int64, -): - inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") - inp1 = T.match_buffer(inp1_handle, [D1, D2, D3], dtype="float16") - out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") - for ax0, ax1, ax2 in T.grid(D1, D2, D3): - with T.block("T_mul"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(inp0[v_ax0, v_ax1, v_ax2], inp1[v_ax0, v_ax1, v_ax2]) - T.writes(out[v_ax0, v_ax1, v_ax2]) - out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * inp1[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def cast_3d_fp16( - inp0_handle: T.handle, - out_handle: T.handle, - D1: T.int64, - D2: T.int64, - D3: T.int64, -): - inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32") - out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") - for ax0, ax1, ax2 in T.grid(D1, D2, D3): - with T.block("T_cast"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(inp0[v_ax0, v_ax1, v_ax2]) - T.writes(out[v_ax0, v_ax1, v_ax2]) - out[v_ax0, v_ax1, v_ax2] = T.Cast("float16", inp0[v_ax0, v_ax1, v_ax2]) - - -@T.prim_func -def cast_3d_fp32( - inp0_handle: T.handle, - out_handle: T.handle, - D1: T.int64, - D2: T.int64, - D3: T.int64, -): - inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") - out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32") - for ax0, ax1, ax2 in T.grid(D1, D2, D3): - with T.block("T_cast"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(inp0[v_ax0, v_ax1, v_ax2]) - T.writes(out[v_ax0, v_ax1, v_ax2]) - out[v_ax0, v_ax1, v_ax2] = T.Cast("float32", inp0[v_ax0, v_ax1, v_ax2]) - - -def get_tir_pattern() -> List[tvm.tir.PrimFunc]: - """Get the tir patterns for backend dispatch.""" - return [ - matmul_rrr_fp16, - bias_row_2d_fp16, - bias_row_1d_fp16, - batch_bias_row_2d_fp16, - batch_bias_row_1d_fp16, - relu_fp16, - erf_3d_fp32, - batch_matmul_rrr_2d_fp16, - batch_matmul_rrr_3d_fp16, - copy_4d_fp16, - padding_2d_nhwc_fp16, - conv2d_nhwc_fp16, - bias_add_nhwc_2d_fp16, - bias_add_nhwc_1d_fp16, - elem_add_2d_fp16, - elem_add_3d_fp16, - elem_add_4d_fp16, - elem_mul_3d_fp16, - scalar_add_3d_fp16, - scalar_mul_3d_fp16, - cast_3d_fp16, - cast_3d_fp32, - ] diff --git a/tests/python/relax/test_codegen_tir_cutlass.py b/tests/python/relax/test_codegen_tir_cutlass.py deleted file mode 100644 index 9670f1598670..000000000000 --- a/tests/python/relax/test_codegen_tir_cutlass.py +++ /dev/null @@ -1,702 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations -import tempfile - -from tvm import relax, runtime -import tvm -import tvm.testing -from tvm import relax -import scipy -from scipy.special import erf -import numpy as np -from tvm.target import Target -from tvm.relax.vm_build import build as relax_build -from tvm.script.ir_builder import relax as R -from tvm.script.ir_builder import ir as I -from tvm.script.ir_builder import tir as T -from tvm.script.ir_builder import IRBuilder - -from tvm.relax.backend_tir import get_tir_pattern -from tvm.relax.backend_tir.contrib.cutlass import cutlass_fcodegen, compile_options - -A_TYPE = "float16" -B_TYPE = "float16" -C_TYPE = "float16" - -target = Target("cuda") - - -def f_run(rt_mod: runtime.Module, device: runtime.ndarray.Device, *input): - vm = relax.vm.VirtualMachine(rt_mod=rt_mod, device=device) - return vm["main"](*input) - - -def build(mod): - mod = relax.transform.LegalizeOps()(mod) - mod = relax.transform.AnnotateTIROpPattern()(mod) - mod = relax.transform.FuseOps()(mod) - mod = relax.transform.FuseTIR()(mod) - mod = relax.transform.SplitCallTIRByPattern(get_tir_pattern(), cutlass_fcodegen())(mod) - mod = relax.transform.DeadCodeElimination()(mod) - executable = tvm.compile(mod, target) - return executable.jit(**compile_options(target)) - - -def build_and_run_reference(mod, inputs_np): - dev = tvm.device("llvm", 0) - ex = tvm.compile(mod, "llvm") - vm = relax.VirtualMachine(ex, dev) - f = vm["main"] - inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] - return f(*inputs).numpy() - - -def constructGEMM(M, N, K): - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - A = R.arg( - "A", relax.TensorStructInfo((M, K), A_TYPE) - ) # pylint: disable=invalid-name - B = R.arg( - "B", relax.TensorStructInfo((K, N), B_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) - R.output(C) - (C,) = df.output_vars - R.func_ret_value(C) - relax_mod = ib.get() - return relax_mod - - -@tvm.testing.requires_cutlass -def test_cutlass_dense(): - m, n, k = 128, 64, 256 - executable = build(constructGEMM(m, n, k)) - dev = tvm.cuda() - A = np.random.randn(m, k).astype("float16") - B = np.random.randn(k, n).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - result = f_run(executable, dev, A_tvm, B_tvm) - np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) - - -def constructGEMM_bias(M, N, K): - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - A = R.arg( - "A", relax.TensorStructInfo((M, K), A_TYPE) - ) # pylint: disable=invalid-name - B = R.arg( - "B", relax.TensorStructInfo((K, N), B_TYPE) - ) # pylint: disable=invalid-name - bias = R.arg( - "bias", relax.TensorStructInfo((1, N), A_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) - D = R.emit(R.add(C, bias)) - R.output(D) - (D,) = df.output_vars - R.func_ret_value(D) - relax_mod = ib.get() - return relax_mod - - -def constructGEMM_bias2(M, N, K): - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - A = R.arg( - "A", relax.TensorStructInfo((M, K), A_TYPE) - ) # pylint: disable=invalid-name - B = R.arg( - "B", relax.TensorStructInfo((K, N), B_TYPE) - ) # pylint: disable=invalid-name - bias = R.arg( - "bias", relax.TensorStructInfo((N,), A_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) - D = R.emit(R.add(C, bias)) - R.output(D) - (D,) = df.output_vars - R.func_ret_value(D) - relax_mod = ib.get() - return relax_mod - - -@tvm.testing.requires_cutlass -def test_cutlass_dense_bias(): - m, n, k = 128, 64, 256 - executable = build(constructGEMM_bias(m, n, k)) - dev = tvm.cuda() - A = np.random.randn(m, k).astype("float16") - B = np.random.randn(k, n).astype("float16") - bias = np.random.randn(1, n).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias, dev) - result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) - np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) - - -@tvm.testing.requires_cutlass -def test_cutlass_dense_bias2(): - m, n, k = 128, 64, 256 - executable = build(constructGEMM_bias2(m, n, k)) - dev = tvm.cuda() - A = np.random.randn(m, k).astype("float16") - B = np.random.randn(k, n).astype("float16") - bias = np.random.randn(n).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias, dev) - result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) - np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) - - -def constructGEMM_bias_relu(M, N, K): - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - A = R.arg( - "A", relax.TensorStructInfo((M, K), A_TYPE) - ) # pylint: disable=invalid-name - B = R.arg( - "B", relax.TensorStructInfo((K, N), B_TYPE) - ) # pylint: disable=invalid-name - bias = R.arg( - "bias", relax.TensorStructInfo((1, N), A_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) - D = R.emit(R.add(C, bias)) - E = R.emit(R.nn.relu(D)) - R.output(E) - (E,) = df.output_vars - R.func_ret_value(E) - relax_mod = ib.get() - return relax_mod - - -@tvm.testing.requires_cutlass -def test_cutlass_dense_bias_relu(): - m, n, k = 128, 64, 256 - executable = build(constructGEMM_bias_relu(m, n, k)) - dev = tvm.cuda() - A = np.random.randn(m, k).astype("float16") - B = np.random.randn(k, n).astype("float16") - bias = np.random.randn(1, n).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias, dev) - result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) - np.testing.assert_allclose(result.numpy(), np.maximum(A @ B + bias, 0), rtol=5e-2, atol=5e-2) - - -def constructBatchGEMM(batch, M, N, K): - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - A = R.arg( - "A", relax.TensorStructInfo((batch, M, K), A_TYPE) - ) # pylint: disable=invalid-name - B = R.arg( - "B", relax.TensorStructInfo((K, N), B_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) - R.output(C) - (C,) = df.output_vars - R.func_ret_value(C) - relax_mod = ib.get() - return relax_mod - - -@tvm.testing.requires_cutlass -def test_cutlass_batch_dense(): - b, m, n, k = 2, 128, 256, 64 - executable = build(constructBatchGEMM(b, m, n, k)) - dev = tvm.cuda() - A = np.random.randn(b, m, k).astype("float16") - B = np.random.randn(k, n).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - result = f_run(executable, dev, A_tvm, B_tvm) - np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) - - -def constructBatchGEMM2(batch, M, N, K): - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - A = R.arg( - "A", relax.TensorStructInfo((batch, M, K), A_TYPE) - ) # pylint: disable=invalid-name - B = R.arg( - "B", relax.TensorStructInfo((batch, K, N), B_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) - R.output(C) - (C,) = df.output_vars - R.func_ret_value(C) - relax_mod = ib.get() - return relax_mod - - -@tvm.testing.requires_cutlass -def test_cutlass_batch_dense2(): - b, m, n, k = 2, 128, 256, 64 - executable = build(constructBatchGEMM2(b, m, n, k)) - dev = tvm.cuda() - A = np.random.randn(b, m, k).astype("float16") - B = np.random.randn(b, k, n).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - result = f_run(executable, dev, A_tvm, B_tvm) - np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) - - -def constructBatchGEMM_bias(batch, M, N, K): - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - A = R.arg( - "A", relax.TensorStructInfo((batch, M, K), A_TYPE) - ) # pylint: disable=invalid-name - B = R.arg( - "B", relax.TensorStructInfo((K, N), B_TYPE) - ) # pylint: disable=invalid-name - bias = R.arg( - "bias", relax.TensorStructInfo((1, N), A_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) - D = R.emit(R.add(C, bias)) - R.output(D) - (D,) = df.output_vars - R.func_ret_value(D) - relax_mod = ib.get() - return relax_mod - - -@tvm.testing.requires_cutlass -def test_cutlass_batch_dense_bias(): - b, m, n, k = 2, 128, 256, 64 - executable = build(constructBatchGEMM_bias(b, m, n, k)) - dev = tvm.cuda() - A = np.random.randn(b, m, k).astype("float16") - B = np.random.randn(k, n).astype("float16") - bias = np.random.randn(1, n).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias, dev) - result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) - np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) - - -def constructBatchGEMM_bias2(batch, M, N, K): - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - A = R.arg( - "A", relax.TensorStructInfo((batch, M, K), A_TYPE) - ) # pylint: disable=invalid-name - B = R.arg( - "B", relax.TensorStructInfo((K, N), B_TYPE) - ) # pylint: disable=invalid-name - bias = R.arg( - "bias", relax.TensorStructInfo((N,), A_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) - D = R.emit(R.add(C, bias)) - R.output(D) - (D,) = df.output_vars - R.func_ret_value(D) - relax_mod = ib.get() - return relax_mod - - -@tvm.testing.requires_cutlass -def test_cutlass_batch_dense_bias2(): - b, m, n, k = 2, 128, 256, 64 - executable = build(constructBatchGEMM_bias2(b, m, n, k)) - dev = tvm.cuda() - A = np.random.randn(b, m, k).astype("float16") - B = np.random.randn(k, n).astype("float16") - bias = np.random.randn(n).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias, dev) - result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) - np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) - - -def constructBatchGEMM_bias2_gelu(batch, M, N, K): - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - A = R.arg( - "A", relax.TensorStructInfo((batch, M, K), A_TYPE) - ) # pylint: disable=invalid-name - B = R.arg( - "B", relax.TensorStructInfo((K, N), B_TYPE) - ) # pylint: disable=invalid-name - bias = R.arg( - "bias", relax.TensorStructInfo((N,), A_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) - D = R.emit(R.add(C, bias)) - E = R.emit(R.nn.gelu(D)) - R.output(E) - (E,) = df.output_vars - R.func_ret_value(E) - relax_mod = ib.get() - return relax_mod - - -@tvm.testing.requires_cutlass -def test_cutlass_batch_dense_bias2_gelu(): - b, m, n, k = 2, 128, 64, 256 - executable = build(constructBatchGEMM_bias2_gelu(b, m, n, k)) - dev = tvm.cuda() - A = np.random.randn(b, m, k).astype("float16") - B = np.random.randn(k, n).astype("float16") - bias = np.random.randn(n).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias, dev) - result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) - C = A @ B + bias - O = 0.5 * C * (1 + erf(C / np.sqrt(2))) - np.testing.assert_allclose(result.numpy(), O, rtol=5e-2, atol=5e-2) - - -def constructBatchGEMM_bias2_mul(batch, M, N, K): - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - A = R.arg( - "A", relax.TensorStructInfo((batch, M, K), A_TYPE) - ) # pylint: disable=invalid-name - B = R.arg( - "B", relax.TensorStructInfo((K, N), B_TYPE) - ) # pylint: disable=invalid-name - bias = R.arg( - "bias", relax.TensorStructInfo((N,), A_TYPE) - ) # pylint: disable=invalid-name - residual = R.arg("residual", relax.TensorStructInfo((batch, M, N), A_TYPE)) - with R.dataflow() as df: - C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) - D = R.emit(R.add(C, bias)) - E = R.emit(R.multiply(D, residual)) - R.output(E) - (E,) = df.output_vars - R.func_ret_value(E) - relax_mod = ib.get() - return relax_mod - - -@tvm.testing.requires_cutlass -def test_cutlass_batch_dense_bias2_mul(): - b, m, n, k = 2, 128, 256, 64 - executable = build(constructBatchGEMM_bias2_mul(b, m, n, k)) - dev = tvm.cuda() - A = np.random.randn(b, m, k).astype("float16") - B = np.random.randn(k, n).astype("float16") - bias = np.random.randn(n).astype("float16") - residual = np.random.randn(b, m, n).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias, dev) - residual_tvm = tvm.nd.array(residual, dev) - result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, residual_tvm) - np.testing.assert_allclose(result.numpy(), ((A @ B) + bias) * residual, rtol=5e-2, atol=5e-2) - - -def constructBatchGEMM2_bias(batch, M, N, K): - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - A = R.arg( - "A", relax.TensorStructInfo((batch, M, K), A_TYPE) - ) # pylint: disable=invalid-name - B = R.arg( - "B", relax.TensorStructInfo((batch, K, N), B_TYPE) - ) # pylint: disable=invalid-name - bias = R.arg( - "bias", relax.TensorStructInfo((1, N), A_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) - D = R.emit(R.add(C, bias)) - R.output(D) - (D,) = df.output_vars - R.func_ret_value(D) - relax_mod = ib.get() - return relax_mod - - -@tvm.testing.requires_cutlass -def test_cutlass_batch_dense2_bias(): - b, m, n, k = 2, 128, 256, 64 - executable = build(constructBatchGEMM2_bias(b, m, n, k)) - dev = tvm.cuda() - A = np.random.randn(b, m, k).astype("float16") - B = np.random.randn(b, k, n).astype("float16") - bias = np.random.randn(1, n).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias, dev) - result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) - np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) - - -def constructConv2D(N, C, H, W, KH, KW, O, strides, padding, dilation): - from tvm.script.ir_builder import IRBuilder - from tvm.script.ir_builder import ir as I - from tvm.script.ir_builder import relax as R - from tvm.script.ir_builder import tir as T - - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - x = R.arg( - "x", relax.TensorStructInfo((N, H, W, C), A_TYPE) - ) # pylint: disable=invalid-name - w = R.arg( - "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit( - R.nn.conv2d( - x, - w, - strides=strides, - padding=padding, - dilation=dilation, - groups=1, - data_layout="NHWC", - kernel_layout="OHWI", - out_layout="NHWC", - out_dtype=C_TYPE, - ) - ) - R.output(C) - (C,) = df.output_vars - R.func_ret_value(C) - mod = ib.get() - return mod - - -@tvm.testing.requires_cutlass -def test_cutlass_conv2d(): - n, c, h, w = 1, 3, 224, 224 - kh, kw, o = 3, 3, 64 - for strides in [(1, 1), (2, 2)]: - for padding in [(0, 0), (3, 3)]: - for dilation in [(1, 1), (4, 4)]: - mod = constructConv2D(n, c, h, w, kh, kw, o, strides, padding, dilation) - executable = build(mod) - dev = tvm.cuda() - np.random.seed(0) - A = np.random.randn(n, h, w, c).astype("float16") - B = np.random.randn(o, kh, kw, c).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - result = f_run(executable, dev, A_tvm, B_tvm) - result_ref = build_and_run_reference(mod, [A, B]) - np.testing.assert_allclose( - result.numpy(), - result_ref, - rtol=5e-2, - atol=5e-2, - ) - - -def constructConv2D_bias(N, C, H, W, KH, KW, O, strides, padding, dilation): - from tvm.script.ir_builder import IRBuilder - from tvm.script.ir_builder import ir as I - from tvm.script.ir_builder import relax as R - from tvm.script.ir_builder import tir as T - - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - x = R.arg( - "x", relax.TensorStructInfo((N, H, W, C), A_TYPE) - ) # pylint: disable=invalid-name - w = R.arg( - "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE) - ) # pylint: disable=invalid-name - bias = R.arg( - "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit( - R.nn.conv2d( - x, - w, - strides=strides, - padding=padding, - dilation=dilation, - groups=1, - data_layout="NHWC", - kernel_layout="OHWI", - out_layout="NHWC", - out_dtype=C_TYPE, - ) - ) - D = R.emit(R.add(C, bias)) - R.output(D) - (D,) = df.output_vars - R.func_ret_value(D) - mod = ib.get() - return mod - - -@tvm.testing.requires_cutlass -def test_cutlass_conv2d_bias(): - c, h, w = 3, 224, 224 - kh, kw, o = 3, 3, 64 - for n in [1, 2]: - for strides in [(1, 1), (2, 2)]: - for padding in [(0, 0), (3, 3)]: - for dilation in [(1, 1), (4, 4)]: - mod = constructConv2D_bias(n, c, h, w, kh, kw, o, strides, padding, dilation) - executable = build(mod) - dev = tvm.cuda() - np.random.seed(0) - A = np.random.randn(n, h, w, c).astype("float16") - B = np.random.randn(o, kh, kw, c).astype("float16") - bias = np.random.randn(1, 1, 1, o).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias, dev) - result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) - result_ref = build_and_run_reference(mod, [A, B, bias]) - np.testing.assert_allclose( - result.numpy(), - result_ref, - rtol=5e-2, - atol=5e-2, - ) - - -def constructConv2D_bias_add(N, C, H, W, KH, KW, O, OH, OW, strides, padding, dilation): - from tvm.script.ir_builder import IRBuilder - from tvm.script.ir_builder import ir as I - from tvm.script.ir_builder import relax as R - from tvm.script.ir_builder import tir as T - - with IRBuilder() as ib: # pylint: disable=invalid-name - with I.ir_module() as frame: - with R.function(): - R.func_name("main") - x = R.arg( - "x", relax.TensorStructInfo((N, H, W, C), A_TYPE) - ) # pylint: disable=invalid-name - w = R.arg( - "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE) - ) # pylint: disable=invalid-name - bias = R.arg( - "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE) - ) # pylint: disable=invalid-name - res = R.arg( - "res", relax.TensorStructInfo((N, OH, OW, O), A_TYPE) - ) # pylint: disable=invalid-name - with R.dataflow() as df: - C = R.emit( - R.nn.conv2d( - x, - w, - strides=strides, - padding=padding, - dilation=dilation, - groups=1, - data_layout="NHWC", - kernel_layout="OHWI", - out_layout="NHWC", - out_dtype=C_TYPE, - ) - ) - D = R.emit(R.add(C, bias)) - E = R.emit(R.add(D, res)) - R.output(E) - (E,) = df.output_vars - R.func_ret_value(E) - mod = ib.get() - return mod - - -@tvm.testing.requires_cutlass -def test_cutlass_conv2d_bias_add(): - n, c, h, w = 2, 3, 224, 224 - kh, kw, o = 3, 3, 64 - for strides in [(1, 1), (2, 2)]: - for padding in [(0, 0), (3, 3)]: - for dilation in [(1, 1), (4, 4)]: - oh = (h + 2 * padding[0] - dilation[0] * (kh - 1) - 1) // strides[0] + 1 - ow = (w + 2 * padding[1] - dilation[1] * (kw - 1) - 1) // strides[1] + 1 - mod = constructConv2D_bias_add( - n, c, h, w, kh, kw, o, oh, ow, strides, padding, dilation - ) - executable = build(mod) - dev = tvm.cuda() - np.random.seed(0) - A = np.random.randn(n, h, w, c).astype("float16") - B = np.random.randn(o, kh, kw, c).astype("float16") - bias = np.random.randn(1, 1, 1, o).astype("float16") - res = np.random.randn(n, oh, ow, o).astype("float16") - A_tvm = tvm.nd.array(A, dev) - B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias, dev) - res_tvm = tvm.nd.array(res, dev) - result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, res_tvm) - result_ref = build_and_run_reference(mod, [A, B, bias, res]) - np.testing.assert_allclose( - result.numpy(), - result_ref, - rtol=5e-2, - atol=5e-2, - ) - - -if __name__ == "__main__": - tvm.testing.main()