From d765566299538581b5b7d7ea83447b5d9d833693 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Fri, 10 Mar 2023 10:11:38 -0800 Subject: [PATCH 1/2] split call tir --- include/tvm/relax/tir_pattern.h | 75 ++ python/tvm/contrib/cutlass/gemm_operation.py | 3 +- python/tvm/relax/backend_tir/__init__.py | 20 + .../tvm/relax/backend_tir/contrib/__init__.py | 20 + .../tvm/relax/backend_tir/contrib/cutlass.py | 720 ++++++++++++++++ python/tvm/relax/backend_tir/pattern.py | 576 +++++++++++++ python/tvm/relax/transform/transform.py | 19 + src/relax/backend/vm/codegen_vm.cc | 11 + src/relax/ir/tir_pattern.cc | 37 + .../transform/split_call_tir_by_pattern.cc | 782 ++++++++++++++++++ .../python/relax/test_codegen_tir_cutlass.py | 755 +++++++++++++++++ 11 files changed, 3017 insertions(+), 1 deletion(-) create mode 100644 include/tvm/relax/tir_pattern.h create mode 100644 python/tvm/relax/backend_tir/__init__.py create mode 100644 python/tvm/relax/backend_tir/contrib/__init__.py create mode 100644 python/tvm/relax/backend_tir/contrib/cutlass.py create mode 100644 python/tvm/relax/backend_tir/pattern.py create mode 100644 src/relax/ir/tir_pattern.cc create mode 100644 src/relax/transform/split_call_tir_by_pattern.cc create mode 100644 tests/python/relax/test_codegen_tir_cutlass.py diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h new file mode 100644 index 000000000000..02634dcbbf71 --- /dev/null +++ b/include/tvm/relax/tir_pattern.h @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir_pattern.h + * \brief Data Structure of TIR Pattern used for matching. + */ + +#ifndef TVM_RELAX_TIR_PATTERN_H_ +#define TVM_RELAX_TIR_PATTERN_H_ + +#include + +namespace tvm { +namespace relax { + +using TIRPattern = tir::PrimFunc; + +/* + * \brief The match result of a TIR pattern. + */ +class MatchResultNode : public Object { + public: + /*! The matched tir pattern*/ + TIRPattern pattern; + /*! \brief The evaluated values of symbolic vars. */ + Array symbol_values; + /*! \brief The matched buffers of input and output. */ + Array matched_buffers; + void VisitAttrs(AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("symbol_values", &symbol_values); + v->Visit("matched_buffers", &matched_buffers); + } + static constexpr const char* _type_key = "relax.MatchResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(MatchResultNode, Object); +}; + +/*! + * \brief Managed reference to MatchResultNode. + */ +class MatchResult : public ObjectRef { + public: + /*! + * \brief Constructor + * \param pattern The matched tir pattern. + * \param symbol_values The evaluated values of symbolic vars. + * \param matched_buffers The matched buffers of input and output. 
+ */ + TVM_DLL explicit MatchResult(TIRPattern pattern, Array symbol_values, + Array matched_buffers); + + TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode) +}; + +using FCodegen = runtime::TypedPackedFunc(Array match_results)>; +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_TIR_PATTERN_H_ diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py index eb9f92dad39a..b820ead016fe 100644 --- a/python/tvm/contrib/cutlass/gemm_operation.py +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -369,7 +369,8 @@ def instantiate_gemm_template(attrs): { "bias_decl": "void* ptr_bias = (void*)(${bias_arg}->data);\n", "ptr_c": "ptr_bias", - "c_stride": "${bias_arg}->ndim == 1 ? 0 : " + attrs["ldc"], + "c_stride": "(${bias_arg}->ndim == 1 || ${bias_arg}->shape[0] == 1) ? 0 : " + + attrs["ldc"], } ) else: diff --git a/python/tvm/relax/backend_tir/__init__.py b/python/tvm/relax/backend_tir/__init__.py new file mode 100644 index 000000000000..eeb8fe438f6e --- /dev/null +++ b/python/tvm/relax/backend_tir/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax backends, tir based""" + +from . import contrib +from .pattern import get_tir_pattern diff --git a/python/tvm/relax/backend_tir/contrib/__init__.py b/python/tvm/relax/backend_tir/contrib/__init__.py new file mode 100644 index 000000000000..9274f22374b9 --- /dev/null +++ b/python/tvm/relax/backend_tir/contrib/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""External backend codegen modules for Relax, tir based.""" + +from .cutlass import cutlass_fcodegen diff --git a/python/tvm/relax/backend_tir/contrib/cutlass.py b/python/tvm/relax/backend_tir/contrib/cutlass.py new file mode 100644 index 000000000000..0dbe31c468ad --- /dev/null +++ b/python/tvm/relax/backend_tir/contrib/cutlass.py @@ -0,0 +1,720 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,comparison-with-callable,unused-variable,missing-function-docstring +"""codegen for cutlass""" +import operator +from functools import reduce +from typing import List, Dict, Any + +from tvm.contrib.cutlass.build import _get_cutlass_path, _get_cutlass_compile_options +from tvm.contrib.nvcc import get_target_compute_version +from tvm.contrib.cutlass.library import LayoutType, ConvKind +from tvm.contrib.cutlass.gen_tensor_op import instantiate_template +from tvm.contrib.cutlass.gen_gemm import CutlassGemmProfiler +from tvm.contrib.cutlass.gen_conv2d import CutlassConv2DProfiler +from ..pattern import ( + MatchResult, + matmul_rrr_fp16, + bias_row_2d_fp16, + bias_row_1d_fp16, + batch_bias_row_2d_fp16, + batch_bias_row_1d_fp16, + relu_fp16, + erf_3d_fp32, + batch_matmul_rrr_2d_fp16, + batch_matmul_rrr_3d_fp16, + conv2d_nhwc_fp16, + padding_2d_nhwc_fp16, + copy_4d_fp16, + bias_add_nhwc_2d_fp16, + bias_add_nhwc_1d_fp16, + elem_add_4d_fp16, + elem_mul_3d_fp16, + scalar_add_3d_fp16, + scalar_mul_3d_fp16, + cast_3d_fp16, + cast_3d_fp32, +) + +#### helper functions #### +# list representing the anchor ops +# in the future more layouts/dtypes can be supported +MATMUL_LIST = [matmul_rrr_fp16] +MATMUL_BIAS_LIST = [bias_row_2d_fp16, bias_row_1d_fp16] +BATCH_MATMUL_LIST = [batch_matmul_rrr_2d_fp16, batch_matmul_rrr_3d_fp16] +BATCH_MATMUL_BIAS_LIST = [batch_bias_row_2d_fp16, batch_bias_row_1d_fp16] +CONV2D_LIST = [conv2d_nhwc_fp16] + +# attributes for anchor ops used in code generation +OP_PATTERN_ATTR_LIST = { + matmul_rrr_fp16: { + "arg0_dtype": "float16", + "arg1_dtype": "float16", + "ret_dtype": "float16", + }, + batch_matmul_rrr_2d_fp16: { + "arg0_dtype": "float16", + "arg1_dtype": "float16", + "ret_dtype": "float16", + }, + batch_matmul_rrr_3d_fp16: { + "arg0_dtype": "float16", + "arg1_dtype": "float16", + "ret_dtype": "float16", + }, + conv2d_nhwc_fp16: { + "arg0_dtype": "float16", + "arg1_dtype": "float16", + "ret_dtype": "float16", + # in the future we can add layout here + }, +} + + +def _get_cutlass_code(attr): + pattern = attr["op_type"] + if pattern.startswith("cutlass.matmul"): + return cutlass_codegen_gemm(attr) + elif pattern.startswith("cutlass.conv2d"): + return cutlass_codegen_conv2d(attr) + else: + raise ValueError("op not supported") + + +def _final_code(code, headers, func_args): + res = "" + res += "#define DMLC_USE_LOGGING_LIBRARY \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + + for header in headers: + res += "#include <" + header + ">\n" + res += "namespace {\n" + res += "using namespace tvm;\n" + res += "using namespace tvm::runtime;\n" + res += "void _cutlass_kernel(" + for arg in func_args: + res += "NDArray " + arg + ", " + res 
+= "NDArray out0) {" + res += code + res += "}\n" + res += "} // namespace\n" + res += "TVM_DLL_EXPORT_TYPED_FUNC({global_symbol}, _cutlass_kernel);\n" + return res + + +#### cutlass patterns #### +def matmul_bias_relu(match_results, attr, get_code=True): + if len(match_results) < 3: + return None + attr = matmul_bias(match_results[:2], attr, get_code=False) + if attr is None or match_results[2].pattern != relu_fp16: + return None + m_bias, n_bias = match_results[1].symbol_values + m_relu, n_relu = match_results[2].symbol_values + A_bias, B_bias, C_bias = match_results[1].matched_buffers + A_relu, B_relu = match_results[2].matched_buffers + if m_bias == m_relu and n_bias == n_relu and C_bias == A_relu: + attr["op_type"] = "cutlass.matmul_bias_relu" + return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr + return None + + +def matmul_bias(match_results, attr, get_code=True): + if len(match_results) < 2: + return None + attr = matmul(match_results[:1], attr, get_code=False) + if attr is None or match_results[1].pattern not in MATMUL_BIAS_LIST: + return None + m_matmul, n_matmul, k_matmul = match_results[0].symbol_values + m_bias, n_bias = match_results[1].symbol_values + A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers + A_bias, B_bias, C_bias = match_results[1].matched_buffers + if m_matmul == m_bias and n_matmul == n_bias and C_matmul == A_bias: + attr["op_type"] = "cutlass.matmul_bias" + attr["bias_arg_idx"] = 2 + attr["args"].append(B_bias) + return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr + return None + + +def matmul(match_results, attr, get_code=True): + if len(match_results) < 1: + return None + if match_results[0].pattern in MATMUL_LIST: + # matmul + attr["op_type"] = "cutlass.matmul" + return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code else attr + return None + + +def batch_matmul_bias_gelu(match_results, attr, get_code=True): + if len(match_results) < 9: + return None + attr = batch_matmul_bias(match_results[:2], attr, get_code=False) # batch_matmul, batch_bias + if ( + attr is None + or match_results[2].pattern != scalar_mul_3d_fp16 + or match_results[3].pattern != cast_3d_fp32 + or match_results[4].pattern != erf_3d_fp32 + or match_results[5].pattern != cast_3d_fp16 + or match_results[6].pattern != scalar_mul_3d_fp16 + or match_results[7].pattern != scalar_add_3d_fp16 + or match_results[8].pattern != elem_mul_3d_fp16 + ): + return None + + def shape_match_3d(shape1, shape2): + if len(shape1) < 3 or len(shape2) < 3: + return False + return shape1[0] == shape2[0] and shape1[1] == shape2[1] and shape1[2] == shape2[2] + + for i in range(1, 8): + if not shape_match_3d(match_results[i].symbol_values, match_results[i + 1].symbol_values): + return None + + if not ( + match_results[1].matched_buffers[-1] == match_results[2].matched_buffers[0] + and match_results[2].matched_buffers[-1] == match_results[3].matched_buffers[0] + and match_results[3].matched_buffers[-1] == match_results[4].matched_buffers[0] + and match_results[4].matched_buffers[-1] == match_results[5].matched_buffers[0] + and match_results[5].matched_buffers[-1] == match_results[6].matched_buffers[0] + and match_results[6].matched_buffers[-1] == match_results[7].matched_buffers[0] + and match_results[1].matched_buffers[-1] == match_results[8].matched_buffers[0] + and match_results[7].matched_buffers[-1] == match_results[8].matched_buffers[1] + ): + return None + + if ( + abs(float(match_results[2].symbol_values[-1] - 0.5**0.5)) > 1e-5 + or 
abs(float(match_results[6].symbol_values[-1] - 0.5)) > 1e-5 + or abs(float(match_results[7].symbol_values[-1] - 0.5)) > 1e-5 + ): + return None + + attr["op_type"] = "cutlass.matmul_bias_gelu" + return [_get_cutlass_code(attr=attr), 9, attr["args"]] if get_code else attr + + +def batch_matmul_bias_residual_mul(match_results, attr, get_code=True): + if len(match_results) < 3: + return None + attr = batch_matmul_bias(match_results[:2], attr, get_code=False) # batch_matmul, batch_bias + if attr is None or match_results[2].pattern != elem_mul_3d_fp16: + return None + ( + b_bias, + m_bias, + n_bias, + ) = match_results[1].symbol_values + ( + b_mul, + m_mul, + n_mul, + ) = match_results[2].symbol_values + A_bias, B_bias, C_bias = match_results[1].matched_buffers + A_mul, B_mul, C_mul = match_results[2].matched_buffers + if b_bias == b_mul and m_bias == m_mul and n_bias == n_mul and C_bias == A_mul: + attr["op_type"] = "cutlass.matmul_bias_residual_multiply" + attr["residual_arg_idx"] = 3 + return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr + return None + + +def batch_matmul_bias(match_results, attr, get_code=True): + if len(match_results) < 2: + return None + attr = batch_matmul(match_results[:1], attr, get_code=False) + if attr is None or match_results[1].pattern not in BATCH_MATMUL_BIAS_LIST: + return None + ( + b_matmul, + m_matmul, + n_matmul, + k_matmul, + ) = match_results[0].symbol_values + ( + b_bias, + m_bias, + n_bias, + ) = match_results[1].symbol_values + A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers + A_bias, B_bias, C_bias = match_results[1].matched_buffers + if b_matmul == b_bias and m_matmul == m_bias and n_matmul == n_bias and C_matmul == A_bias: + attr["op_type"] = "cutlass.matmul_bias" + attr["bias_arg_idx"] = 2 + attr["args"].append(B_bias) + return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr + return None + + +def batch_matmul(match_results, attr, get_code=True): + if len(match_results) < 1: + return None + if match_results[0].pattern in BATCH_MATMUL_LIST: + attr["op_type"] = "cutlass.matmul" + return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code else attr + return None + + +def conv2d_bias_residual_add(match_results, attr, get_code=True): + if len(match_results) < 4: + return None + attr = conv2d_bias(match_results[:3], attr, get_code=False) + if attr is None or match_results[3].pattern != elem_add_4d_fp16: + return None + N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values + in1_bias, in2_bias, out_bias = match_results[2].matched_buffers + N_add, H_add, W_add, C_add = match_results[3].symbol_values + in1_add, in2_add, out_add = match_results[3].matched_buffers + if ( + N_bias == N_add + and H_bias == H_add + and W_bias == W_add + and C_bias == C_add + and out_bias in [in1_add, in2_add] + ): + attr["op_type"] = "cutlass.conv2d_bias_residual_add" + attr["residual_arg_idx"] = 3 + attr["args"].append(in2_add if out_bias == in1_add else in1_add) + return [_get_cutlass_code(attr=attr), 4, attr["args"]] if get_code else attr + return None + + +def conv2d_bias(match_results, attr, get_code=True): + if len(match_results) < 3: + return None + attr = conv2d(match_results[:2], attr, get_code=False) + if attr is None or ( + match_results[2].pattern not in [bias_add_nhwc_2d_fp16, bias_add_nhwc_1d_fp16] + ): + return None + (N_conv, pH_conv, pW_conv, H_conv, W_conv, C_conv, O_conv,) = match_results[ + 1 + ].symbol_values[:7] + A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers + 
N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values + A_bias, B_bias, out_bias = match_results[2].matched_buffers + if ( + N_bias == N_conv + and H_bias == H_conv + and W_bias == W_conv + and C_bias == O_conv + and out_conv == A_bias + ): + attr["op_type"] = "cutlass.conv2d_bias" + attr["bias_arg_idx"] = 2 + attr["args"].append(B_bias) + return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr + return None + + +def conv2d(match_results, attr, get_code=True): + if len(match_results) < 2: + return None + if ( + match_results[0].pattern in [padding_2d_nhwc_fp16, copy_4d_fp16] + and match_results[1].pattern == conv2d_nhwc_fp16 + ): + if match_results[0].pattern == padding_2d_nhwc_fp16: + ( + N_pad, + H_pad, + W_pad, + C_pad, + pH_pad, + pW_pad, + lH_pad, + lW_pad, + rH_pad, + rW_pad, + ) = match_results[0].symbol_values + else: + ( + N_pad, + H_pad, + W_pad, + C_pad, + ) = match_results[0].symbol_values + pH_pad = rH_pad = H_pad + pW_pad = rW_pad = W_pad + lH_pad = lW_pad = 0 + ( + N_conv, + pH_conv, + pW_conv, + H_conv, + W_conv, + C_conv, + O_conv, + KH_conv, + KW_conv, + stride_h_conv, + stride_w_conv, + dilation_h_conv, + dilation_w_conv, + ) = match_results[1].symbol_values + A, A_pad = match_results[0].matched_buffers + A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers + if ( + N_pad == N_conv + and pH_pad == pH_conv + and pW_pad == pW_conv + and C_pad == C_conv + and A_pad == A_pad_conv + ): + if ( + lH_pad == pH_pad - rH_pad + and lW_pad == pW_pad - rW_pad + and lH_pad + H_pad == rH_pad + and lW_pad + W_pad == rW_pad + ): + padding = (lH_pad, lW_pad) + strides = (stride_h_conv, stride_w_conv) + dilation = (dilation_h_conv, dilation_w_conv) + attr["padding"] = padding + attr["strides"] = strides + attr["dilation"] = dilation + attr["op_type"] = "cutlass.conv2d" + return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr + return None + + +### cutlass codegen functions ### +def compile_options(target, threads=-1, use_fast_math=False): + compute_version = int("".join(get_target_compute_version(target).split("."))) + kwargs = _get_cutlass_compile_options(compute_version, threads, use_fast_math) + kwargs["options"].remove("-c") + return kwargs + + +def cutlass_fcodegen(sm=80, bin_dir="./bin"): + gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), bin_dir) + conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), bin_dir) + + def cutlass_codegen_with_match_results(match_results: List[MatchResult]): + """generate cutlass code with match results""" + nonlocal gemm_profiler + nonlocal conv2d_profiler + + assert len(match_results) > 0 + + # add shape into attr + if match_results[0].pattern in MATMUL_LIST: + A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers + attr: Dict[Any, Any] = OP_PATTERN_ATTR_LIST[match_results[0].pattern] + attr["args"] = [A_matmul, B_matmul] + attr["arg0_shape"] = A_matmul.shape + attr["arg1_shape"] = B_matmul.shape + attr["ret_shape"] = C_matmul.shape + attr["lhs_arg_idx"] = 0 + attr["rhs_arg_idx"] = 1 + elif match_results[0].pattern in BATCH_MATMUL_LIST: + A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers + attr = OP_PATTERN_ATTR_LIST[match_results[0].pattern] + attr["args"] = [A_matmul, B_matmul] + attr["arg0_shape"] = A_matmul.shape + attr["arg1_shape"] = B_matmul.shape + attr["ret_shape"] = C_matmul.shape + attr["lhs_arg_idx"] = 0 + attr["rhs_arg_idx"] = 1 + elif len(match_results) >= 1 and match_results[1].pattern in CONV2D_LIST: + A_input = 
match_results[0].matched_buffers[0] + A_conv2d, B_conv2d, C_conv2d = match_results[1].matched_buffers + attr = OP_PATTERN_ATTR_LIST[match_results[1].pattern] + attr["args"] = [A_input, B_conv2d] + attr["arg0_shape"] = A_input.shape + attr["arg1_shape"] = B_conv2d.shape + attr["ret_shape"] = C_conv2d.shape + attr["lhs_arg_idx"] = 0 + attr["rhs_arg_idx"] = 1 + else: + return ["", 0] + + # add profiler into attr + attr["gemm_profiler"] = gemm_profiler + attr["conv2d_profiler"] = conv2d_profiler + + cutlass_patterns = [ + # 9 + batch_matmul_bias_gelu, + # 4 + conv2d_bias_residual_add, + # 3 + batch_matmul_bias_residual_mul, + matmul_bias_relu, + conv2d_bias, + # 2 + matmul_bias, + batch_matmul_bias, + conv2d, + # 1 + matmul, + batch_matmul, + ] + for pattern in cutlass_patterns: + res = pattern(match_results, attr) + if res is not None: + return res + + return ["", 0] + + return cutlass_codegen_with_match_results + + +def cutlass_codegen_gemm(attrs): + """cutlass codegen for gemm""" + gemm_profiler = attrs["gemm_profiler"] + op_type = attrs["op_type"] + lhs_shape = attrs["arg0_shape"] + rhs_shape = attrs["arg1_shape"] + MM = lhs_shape[-2] + KK = lhs_shape[-1] + if "transposed" in op_type: + NN = rhs_shape[-2] + ldb = "K" + layout_b = LayoutType.ColumnMajor + else: + NN = rhs_shape[-1] + ldb = "N" + layout_b = LayoutType.RowMajor + + lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1) + rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1) + if lhs_batches == 1 and rhs_batches == 1: + # Regular matmul + is_batched = False + batch_attrs = {} + else: + is_batched = True + batch_attrs = { + # If both lhs_batches and rhs_batches are greater than 1, + # they must be equal. This is checked by is_shape_valid_for_cutlass_matmul. + "batch": lhs_batches if rhs_batches == 1 else rhs_batches, + "batch_stride_A": 0 if lhs_batches == 1 else MM * KK, + "batch_stride_B": 0 if rhs_batches == 1 else KK * NN, + "batch_stride_C": MM * NN, + } + op_name, op_def, _ = gemm_profiler.profile( + op_type, + MM, + NN, + KK, + attrs["ret_dtype"], + attrs["arg0_dtype"], + attrs["arg1_dtype"], + False, + batched=is_batched, + find_first_valid=False, + use_multiprocessing=True, + layout_b=layout_b, + ) + attrs["cutlass_op_name"] = op_name + attrs["cutlass_op_def"] = op_def + attrs["lda"] = "K" + attrs["ldb"] = ldb + attrs["ldc"] = "N" + attrs.update(batch_attrs) + del attrs["gemm_profiler"] + del attrs["conv2d_profiler"] + + nargs = 2 + if "bias_arg_idx" in attrs: + nargs += 1 + if "residual_arg_idx" in attrs: + nargs += 1 + func_args = ["inp" + str(i) for i in range(nargs)] + + # A temporary solution to handle batch matmul residual cases + # TODO(@bohan): remove this after initialize_template supports bmm residual + if op_type in [ + "cutlass.matmul_bias_residual_multiply", + ]: + + def _convert_dtype_str(dtype): + if isinstance(dtype, list): + arr = [] + for t in dtype: + arr.append(_convert_dtype_str(t)) + return arr + elif isinstance(dtype, str): + if dtype == "float16": + return "cutlass::half_t" + elif dtype == "float32": + return "float" + raise ValueError("dtype not supported") + + typea, typeb, typec = _convert_dtype_str( + [attrs["arg0_dtype"], attrs["arg1_dtype"], attrs["ret_dtype"]] + ) + + text = f""" +#define CUTLASS_ENABLE_CUBLAS 1 +#define CUTLASS_NAMESPACE cutlass +#define CUTLASS_ENABLE_TENSOR_CORE_MMA 1 +#define NDEBUG +#include +#include +#include +#include +#include +#include +#include +#include "cutlass/epilogue/thread/activation.h" +#include 
"cutlass/epilogue/thread/linear_combination_residual_block.h" +#include "cutlass/gemm/device/gemm_universal_with_broadcast.h" +#include +#include +#include +#include +#define DMLC_USE_LOGGING_LIBRARY +#include +#include +#include +namespace {{ +using namespace tvm; +using namespace tvm::runtime; +void _BHGEMM(NDArray A, NDArray B, NDArray Bias, NDArray D, NDArray C) {{ + // A: [Batch, M, K], B: [1, K, N]/[K, N], Bias: [1, N]/[N], D: [Batch, M, N], C: [Batch, M, N] + CHECK_EQ(A->ndim, 3); + int bdim = B->ndim; + int bias_dim = Bias->ndim; + CHECK_EQ(C->ndim, 3); + CHECK_EQ(A->shape[2], B->shape[bdim - 2]); + CHECK_EQ(Bias->shape[bias_dim - 1], B->shape[bdim - 1]); + CHECK_EQ(D->ndim, 3); + CHECK_EQ(D->shape[0], A->shape[0]); + CHECK_EQ(D->shape[1], A->shape[1]); + CHECK_EQ(D->shape[2], B->shape[bdim - 1]); + CHECK_EQ(C->shape[0], A->shape[0]); + CHECK_EQ(C->shape[1], A->shape[1]); + CHECK_EQ(C->shape[2], B->shape[bdim - 1]); + int64_t M = A->shape[0] * A->shape[1]; + int64_t N = B->shape[bdim - 1]; + int64_t K = A->shape[2]; + int64_t input_a_batch_stride = M * K; + int64_t input_a_stride = K; + int64_t input_a_offset = 0; // default to 0 + int64_t input_b_batch_stride = K * N; + int64_t input_b_stride = N; + int64_t input_b_offset = 0; // default to 0 + int64_t output_stride = N; + int64_t output_offset = 0; + int64_t a_size = 1; + a_size *= A->shape[0]; + a_size *= A->shape[1]; + a_size *= A->shape[2]; + + int64_t b_size = 1; + b_size *= B->shape[bias_dim - 2]; + b_size *= B->shape[bias_dim - 1]; + + int64_t c_size = 1; + c_size *= C->shape[0]; + c_size *= C->shape[1]; + c_size *= C->shape[2]; + + // Define the GEMM operation + {op_def} + using kernel = Operation_{op_name}; + using ElementComputeEpilogue = typename kernel::ElementAccumulator; + typename kernel::Arguments arguments({{ + cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode + {{M, N, K}}, // GemmCoord problem_size + 1, // int batch_count + {{ElementComputeEpilogue(1), ElementComputeEpilogue(1)}}, // typename EpilogueOutputOp::Params epilogue + ({typea}*)(A->data) + input_a_offset, // void const * ptr_A + ({typeb}*)(B->data) + input_b_offset, // void const * ptr_B + ({typec}*)(D->data), // void const * ptr_C1 + ({typec}*)(C->data) + output_offset, // void * ptr_D + ({typea}*)(Bias->data), // void * ptr_Vector + nullptr, // void * ptr_Tensor + input_a_batch_stride, // int64_t batch_stride_A + input_b_batch_stride, // int64_t batch_stride_B + 0, // int64_t batch_stride_C1 + 0, // int64_t batch_stride_D + 0, // int64_t batch_stride_Vector + 0, // int64_t batch_stride_Tensor + input_a_stride, // typename LayoutA::Stride::Index lda + input_b_stride, // typename LayoutB::Stride::Index ldb + N, // typename LayoutC::Stride::Index ldc1 + output_stride, // typename LayoutC::Stride::Index ldd + 0, // typename LayoutC::Stride::Index ldr + 0, // typename LayoutC::Stride::Index ldt + }}); + kernel gemm_op; + size_t workspace_size = gemm_op.get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + cutlass::Status status = gemm_op.can_implement(arguments); + CHECK(status == cutlass::Status::kSuccess); + status = gemm_op.initialize(arguments, workspace.get()); + CHECK(status == cutlass::Status::kSuccess); + status = gemm_op(); + CHECK(status == cutlass::Status::kSuccess); + return; +}} +}} // namespace +TVM_DLL_EXPORT_TYPED_FUNC({{global_symbol}}, _BHGEMM); + """ + return text + + code = instantiate_template(op_type, attrs, func_args) + return _final_code(code.code, code.headers, func_args) 
+ + +def cutlass_codegen_conv2d(attrs): + """cutlass codegen for conv2d""" + # cutlass backend only supports nhwc for now + conv2d_profiler = attrs["conv2d_profiler"] + op_type = attrs["op_type"] + conv_kind = ConvKind.Fprop + op_name, op_def, _ = conv2d_profiler.profile( + op_type=attrs["op_type"], + d_shape=attrs["arg0_shape"], + w_shape=attrs["arg1_shape"], + padding=attrs["padding"], + stride=attrs["strides"], + dilation=attrs["dilation"], + out_dtype=attrs["ret_dtype"], + data_dtype=attrs["arg0_dtype"], + weight_dtype=attrs["arg1_dtype"], + use_3xtf32=False, + conv_kind=conv_kind, + split_k_slices=[1], + profile_all_alignments=True, + find_first_valid=False, + use_multiprocessing=True, + ) + attrs["cutlass_op_def"] = op_def + attrs["cutlass_op_name"] = op_name + del attrs["gemm_profiler"] + del attrs["conv2d_profiler"] + + nargs = 2 + if "bias_arg_idx" in attrs: + nargs += 1 + if "residual_arg_idx" in attrs: + nargs += 1 + func_args = ["inp" + str(i) for i in range(nargs)] + code = instantiate_template(op_type, attrs, func_args) + return _final_code(code.code, code.headers, func_args) diff --git a/python/tvm/relax/backend_tir/pattern.py b/python/tvm/relax/backend_tir/pattern.py new file mode 100644 index 000000000000..10f7a3b1628d --- /dev/null +++ b/python/tvm/relax/backend_tir/pattern.py @@ -0,0 +1,576 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name,missing-function-docstring,chained-comparison +"""TIR Patterns""" +from typing import List + +import tvm +from tvm.runtime import Object +import tvm._ffi + +from tvm.script import tir as T + + +@tvm._ffi.register_object("relax.MatchResult") +class MatchResult(Object): + """The match result of a TIR pattern.""" + + def __init__(self, pattern, symbol_values, matched_buffers): + self.__init_handle_by_constructor__( + tvm._ffi.MatchResult, pattern, symbol_values, matched_buffers + ) + + +@T.prim_func +def matmul_rrr_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_matmul: T.handle, + M: T.int64, + N: T.int64, + K: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [M, K], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N], dtype="float16") + matmul = T.match_buffer(var_matmul, [M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1, i2 in T.grid(M, N, K): + with T.block("matmul"): + i0_1, i1_1, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, k], rxplaceholder_1[k, i1_1]) + T.writes(matmul[i0_1, i1_1]) + with T.init(): + matmul[i0_1, i1_1] = T.float16(0) + matmul[i0_1, i1_1] = ( + matmul[i0_1, i1_1] + rxplaceholder[i0_1, k] * rxplaceholder_1[k, i1_1] + ) + + +@T.prim_func +def bias_row_2d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_T_add: T.handle, + M: T.int64, + N: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N], dtype="float16") + T_add = T.match_buffer(var_T_add, [M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1 in T.grid(M, N): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[T.int64(0), ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[T.int64(0), ax1] + + +@T.prim_func +def bias_row_1d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_T_add: T.handle, + M: T.int64, + N: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16") + T_add = T.match_buffer(var_T_add, [M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1 in T.grid(M, N): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[ax1] + + +@T.prim_func +def batch_bias_row_2d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_T_add: T.handle, + batch: T.int64, + M: T.int64, + N: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N], dtype="float16") + T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1, i2 in T.grid(batch, M, N): + with T.block("T_add"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2], 
rxplaceholder_1[T.int64(0), ax2]) + T.writes(T_add[ax0, ax1, ax2]) + T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + rxplaceholder_1[T.int64(0), ax2] + + +@T.prim_func +def batch_bias_row_1d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_T_add: T.handle, + batch: T.int64, + M: T.int64, + N: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16") + T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1, i2 in T.grid(batch, M, N): + with T.block("T_add"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[ax2]) + T.writes(T_add[ax0, ax1, ax2]) + T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + rxplaceholder_1[ax2] + + +@T.prim_func +def relu_fp16(var_rxplaceholder: T.handle, var_compute: T.handle, M: T.int64, N: T.int64) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16") + compute = T.match_buffer(var_compute, [M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1 in T.grid(M, N): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float16(0)) + + +@T.prim_func +def batch_matmul_rrr_2d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_matmul: T.handle, + batch: T.int64, + M: T.int64, + N: T.int64, + K: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N], dtype="float16") + matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1, i2, i3 in T.grid(batch, M, N, K): + with T.block("matmul"): + i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[k, i2_1]) + T.writes(matmul[i0_1, i1_1, i2_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1] = T.float16(0) + matmul[i0_1, i1_1, i2_1] = ( + matmul[i0_1, i1_1, i2_1] + rxplaceholder[i0_1, i1_1, k] * rxplaceholder_1[k, i2_1] + ) + + +@T.prim_func +def batch_matmul_rrr_3d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_matmul: T.handle, + batch: T.int64, + M: T.int64, + N: T.int64, + K: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [batch, K, N], dtype="float16") + matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1, i2, i3 in T.grid(batch, M, N, K): + with T.block("matmul"): + i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[i0_1, k, i2_1]) + T.writes(matmul[i0_1, i1_1, i2_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1] = T.float16(0) + matmul[i0_1, i1_1, i2_1] = ( + matmul[i0_1, i1_1, i2_1] + + rxplaceholder[i0_1, i1_1, k] * rxplaceholder_1[i0_1, k, i2_1] + ) + + +@T.prim_func +def copy_4d_fp16( + A_handle: T.handle, + 
B_handle: T.handle, + N: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, +) -> None: + A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") + B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16") + # body + # with T.block("root") + for n, h, w, c in T.grid(N, H, W, C): + with T.block("copy"): + vn, vh, vw, vc = T.axis.remap("SSSS", [n, h, w, c]) + T.reads(A[vn, vh, vw, vc]) + T.writes(B[vn, vh, vw, vc]) + B[vn, vh, vw, vc] = A[vn, vh, vw, vc] + + +@T.prim_func +def padding_2d_nhwc_fp16( + A_handle: T.handle, + B_handle: T.handle, + N: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, + pH: T.int64, + pW: T.int64, + lH: T.int64, + lW: T.int64, + rH: T.int64, + rW: T.int64, +) -> None: + A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") + B = T.match_buffer(B_handle, [N, pH, pW, C], dtype="float16") + # body + # with T.block("root") + for v, v_1, v_2, v_3 in T.grid(N, pH, pW, C): + with T.block("copy"): + v_4, v_5, v_6, v_7 = T.axis.remap("SSSS", [v, v_1, v_2, v_3]) + T.reads(A[v_4, v_5 - lH, v_6 - lW, v_7]) + T.writes(B[v_4, v_5, v_6, v_7]) + B[v_4, v_5, v_6, v_7] = T.if_then_else( + lH <= v_5 and v_5 < rH and lW <= v_6 and v_6 < rW, + A[v_4, v_5 - lH, v_6 - lW, v_7], + T.float16(0), + dtype="float16", + ) + + +@T.prim_func +def conv2d_nhwc_fp16( + A_handle: T.handle, + B_handle: T.handle, + out_handle: T.handle, + N: T.int64, + pH: T.int64, + pW: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, + O: T.int64, + KH: T.int64, + KW: T.int64, + StrideH: T.int64, + StrideW: T.int64, + DilateH: T.int64, + DilateW: T.int64, +) -> None: + A = T.match_buffer(A_handle, [N, pH, pW, C], dtype="float16") + B = T.match_buffer(B_handle, [O, KH, KW, C], dtype="float16") + out = T.match_buffer(out_handle, [N, H, W, O], dtype="float16") + # body + # with T.block("root") + for v, v_1, v_2, v_3, v_4, v_5, v_6 in T.grid(N, H, W, O, KH, KW, C): + with T.block("conv"): + v_7, v_8, v_9, v_10, v_11, v_12, v_13 = T.axis.remap( + "SSSSRRR", [v, v_1, v_2, v_3, v_4, v_5, v_6] + ) + T.reads( + A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 * StrideW, v_13], + B[v_10, v_11, v_12, v_13], + ) + T.writes(out[v_7, v_8, v_9, v_10]) + with T.init(): + out[v_7, v_8, v_9, v_10] = T.float16(0) + out[v_7, v_8, v_9, v_10] = ( + out[v_7, v_8, v_9, v_10] + + A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 * StrideW, v_13] + * B[v_10, v_11, v_12, v_13] + ) + + +@T.prim_func +def bias_add_nhwc_2d_fp16( + A_handle: T.handle, + B_handle: T.handle, + out_handle: T.handle, + N: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, +): + A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") + B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16") + out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16") + for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), T.int64(0), v_ax3]) + T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3]) + out[v_ax0, v_ax1, v_ax2, v_ax3] = ( + A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, T.int64(0), T.int64(0), v_ax3] + ) + + +@T.prim_func +def bias_add_nhwc_1d_fp16( + A_handle: T.handle, + B_handle: T.handle, + out_handle: T.handle, + N: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, +): + A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") + B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16") + out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16") + for ax0, ax1, ax2, ax3 in 
T.grid(N, H, W, C): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[T.int64(0), T.int64(0), T.int64(0), v_ax3]) + T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3]) + out[v_ax0, v_ax1, v_ax2, v_ax3] = ( + A[v_ax0, v_ax1, v_ax2, v_ax3] + B[T.int64(0), T.int64(0), T.int64(0), v_ax3] + ) + + +@T.prim_func +def elem_add_2d_fp16( + in0_handle: T.handle, + in1_handle: T.handle, + out_handle: T.handle, + N: T.int64, + M: T.int64, +): + in0 = T.match_buffer(in0_handle, [N, M], dtype="float16") + in1 = T.match_buffer(in1_handle, [N, M], dtype="float16") + out = T.match_buffer(out_handle, [N, M], dtype="float16") + for ax0, ax1 in T.grid(N, M): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(in0[v_ax0, v_ax1], in1[v_ax0, v_ax1]) + T.writes(out[v_ax0, v_ax1]) + out[v_ax0, v_ax1] = in0[v_ax0, v_ax1] + in1[v_ax0, v_ax1] + + +@T.prim_func +def elem_add_3d_fp16( + in0_handle: T.handle, + in1_handle: T.handle, + out_handle: T.handle, + B: T.int64, + N: T.int64, + M: T.int64, +): + in0 = T.match_buffer(in0_handle, [B, N, M], dtype="float16") + in1 = T.match_buffer(in1_handle, [B, N, M], dtype="float16") + out = T.match_buffer(out_handle, [B, N, M], dtype="float16") + for ax0, ax1, ax2 in T.grid(B, N, M): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(in0[v_ax0, v_ax1, v_ax2], in1[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = in0[v_ax0, v_ax1, v_ax2] + in1[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def elem_add_4d_fp16( + A_handle: T.handle, + B_handle: T.handle, + out_handle: T.handle, + N: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, +): + A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") + B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16") + out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16") + for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3]) + out[v_ax0, v_ax1, v_ax2, v_ax3] = ( + A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, v_ax1, v_ax2, v_ax3] + ) + + +@T.prim_func +def scalar_mul_3d_fp16( + inp0_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, + scalar: T.float16, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") + out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_mul"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * scalar + + +@T.prim_func +def erf_3d_fp32( + inp0_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32") + out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_erf"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = T.erf(inp0[v_ax0, v_ax1, v_ax2]) + + +@T.prim_func +def scalar_add_3d_fp16( + inp0_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, 
+ scalar: T.float16, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") + out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = scalar + inp0[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def elem_mul_3d_fp16( + inp0_handle: T.handle, + inp1_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") + inp1 = T.match_buffer(inp1_handle, [D1, D2, D3], dtype="float16") + out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_mul"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2], inp1[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * inp1[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def cast_3d_fp16( + inp0_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32") + out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_cast"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = T.Cast("float16", inp0[v_ax0, v_ax1, v_ax2]) + + +@T.prim_func +def cast_3d_fp32( + inp0_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") + out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_cast"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = T.Cast("float32", inp0[v_ax0, v_ax1, v_ax2]) + + +def get_tir_pattern() -> List[tvm.tir.PrimFunc]: + """Get the tir patterns for backend dispatch.""" + return [ + matmul_rrr_fp16, + bias_row_2d_fp16, + bias_row_1d_fp16, + batch_bias_row_2d_fp16, + batch_bias_row_1d_fp16, + relu_fp16, + erf_3d_fp32, + batch_matmul_rrr_2d_fp16, + batch_matmul_rrr_3d_fp16, + copy_4d_fp16, + padding_2d_nhwc_fp16, + conv2d_nhwc_fp16, + bias_add_nhwc_2d_fp16, + bias_add_nhwc_1d_fp16, + elem_add_2d_fp16, + elem_add_3d_fp16, + elem_add_4d_fp16, + elem_mul_3d_fp16, + scalar_add_3d_fp16, + scalar_mul_3d_fp16, + cast_3d_fp16, + cast_3d_fp32, + ] diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index c03df804eef7..18321e8dba33 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -693,6 +693,25 @@ def ToMixedPrecision(out_dtype="float32") -> tvm.ir.transform.Pass: return _ffi_api.ToMixedPrecision(out_dtype) # type: ignore +def SplitCallTIRByPattern(patterns, fcodegen) -> tvm.ir.transform.Pass: + """Split a PrimFunc into 2 parts: the first part is a TIR PrimFunc which is + matched with some pattern, and the second part is the rest of the original + PrimFunc. It will call fcodegen to generate the code for the matched pattern + to replace it with a ExternFunc call. 
+ Parameters + ---------- + patterns : List[PrimFunc] + The list of patterns to match. + fcodegen: Callable[[List[MatchResult]], List[Object]] + The function to generate the code for the matched patterns. + Returns + ------- + ret : tvm.transform.Pass + The registered pass for splitting call_tir. + """ + return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen) # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index da0ca3a0b55e..b36b5ed4d6c6 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -315,6 +315,17 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const ExternFuncNode* op) final { + static const constexpr char* kCSource = "c_source"; + static const constexpr char* kCSourceFmt = "c_source_fmt"; + if (Optional opt_code = op->attrs.GetAttr(kCSource)) { + String sym = op->global_symbol; + String fmt = op->attrs.GetAttr(kCSourceFmt).value_or("c"); + String code = opt_code.value(); + Module c_source_module = + codegen::CSourceModuleCreate(/*code=*/code, /*fmt=*/fmt, /*func_names=*/{sym}, + /*const_vars=*/{}); + builder_->exec()->Import(c_source_module); + } builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc); return builder_->GetFunction(op->global_symbol); } diff --git a/src/relax/ir/tir_pattern.cc b/src/relax/ir/tir_pattern.cc new file mode 100644 index 000000000000..cbe4170bb979 --- /dev/null +++ b/src/relax/ir/tir_pattern.cc @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +namespace tvm { +namespace relax { + +MatchResult::MatchResult(TIRPattern pattern, Array symbol_values, + Array matched_buffers) { + auto n = make_object(); + n->pattern = std::move(pattern); + n->symbol_values = std::move(symbol_values); + n->matched_buffers = std::move(matched_buffers); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(MatchResultNode); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc new file mode 100644 index 000000000000..7fcc2cb34a76 --- /dev/null +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -0,0 +1,782 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/to_non_dataflow.cc + * \brief Transform all dataflow structure to non-dataflow version. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../tir/schedule/ir_comparator.h" + +namespace tvm { + +static const constexpr char* kLibraryKernel = "library_kernel"; +static const constexpr char* kCSource = "c_source"; +static const constexpr char* kCSourceFmt = "c_source_fmt"; +static const constexpr char* kCSourceFmtCuda = "cu"; + +namespace tir { + +using relax::FCodegen; +using relax::MatchResult; +using relax::TIRPattern; + +/*! \brief helper to match a for stmt to a pattern*/ +class ForMatcher : public TensorizeComparator { + public: + using SymbolMap = std::unordered_map; + explicit ForMatcher(const tir::PrimFunc& pattern, const Array& pattern_vars) + : TensorizeComparator(IRModule({{GlobalVar(""), pattern}}), false), pattern_(pattern) { + for (const auto& pattern_var : pattern_vars) { + this->pattern_vars_.insert(pattern_var); + } + this->evaluated_symbols.push_back(SymbolMap()); + } + + bool Match(const For& top) { + const ForNode* pattern_top = pattern_->body.as()->block->body.as(); + ICHECK(pattern_top) << "Invalid pattern function"; + if (!VisitStmt(top, GetRef(pattern_top))) { + return false; + } + // Get evaluated symbols, buffers from the pattern. 
+ for (const auto& arg : pattern_->params) { + auto it = pattern_->buffer_map.find(arg); + if (it != pattern_->buffer_map.end()) { + auto itt = rhs_buffer_map_.find((*it).second); + ICHECK(itt != rhs_buffer_map_.end()); + evaluated_buffers.push_back(itt->second); + } + } + return true; + } + + std::vector evaluated_symbols; + std::vector evaluated_buffers; + + private: + using ExprComparator::VisitExpr_; + + Optional QueryEvaluatedSymbols(const Var& var) { + for (const SymbolMap& symbol_map : evaluated_symbols) { + auto it = symbol_map.find(var); + if (it != symbol_map.end()) { + return it->second; + } + } + return NullOpt; + } + + bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final { + if (const auto* op = rhs.as()) { + if (pattern_vars_.count(GetRef(op))) { + // special case for pattern vars + const auto* lhs_ptr = lhs.as(); + if (lhs_ptr == nullptr) { + if (lhs->IsInstance() || lhs->IsInstance()) { + Optional value = QueryEvaluatedSymbols(GetRef(op)); + if (value.defined()) { + if (!analyzer_.CanProveEqual(lhs, value.value())) return false; + } else { + evaluated_symbols.back()[GetRef(op)] = lhs; + } + return true; + } else { + return false; + } + } + } + } + // pattern_var * expr + if (const auto* rhs_ptr = rhs.as()) { + const auto* operand_a = rhs_ptr->a.as(); + const auto* operand_b = rhs_ptr->b.as(); + if (operand_a != nullptr && pattern_vars_.count(GetRef(operand_a))) { + // pattern var is on the left + evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(lhs, rhs_ptr->b); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + evaluated_symbols.pop_back(); + if (match) { + evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + evaluated_symbols.back()[GetRef(operand_a)] = MakeConstScalar(rhs_ptr->b.dtype(), 1); + return true; + } + } + if (operand_b != nullptr && pattern_vars_.count(GetRef(operand_b))) { + // pattern var is on the right + evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(lhs, rhs_ptr->a); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + evaluated_symbols.pop_back(); + if (match) { + evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + evaluated_symbols.back()[GetRef(operand_b)] = MakeConstScalar(rhs_ptr->a.dtype(), 1); + return true; + } + } + } + // pattern_Var + expr + if (const auto* rhs_ptr = rhs.as()) { + const auto* operand_a = rhs_ptr->a.as(); + const auto* operand_b = rhs_ptr->b.as(); + if (operand_a != nullptr && pattern_vars_.count(GetRef(operand_a))) { + // pattern var is on the left + evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(lhs, rhs_ptr->b); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + evaluated_symbols.pop_back(); + if (match) { + evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + evaluated_symbols.back()[GetRef(operand_a)] = MakeConstScalar(rhs_ptr->b.dtype(), 0); + return true; + } + } + if (operand_b != nullptr && pattern_vars_.count(GetRef(operand_b))) { + // pattern var is on the right + evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(lhs, rhs_ptr->a); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + evaluated_symbols.pop_back(); + if (match) { + evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + evaluated_symbols.back()[GetRef(operand_b)] = MakeConstScalar(rhs_ptr->a.dtype(), 0); + return true; + } + } + } + return TensorizeComparator::VisitExpr(lhs, rhs); + } + + bool VisitExpr_(const tir::AddNode* add, const PrimExpr& 
other) final { + const auto* rhs = other.as(); + if (rhs == nullptr) return false; + { + this->evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(add->a, rhs->a) && VisitExpr(add->b, rhs->b); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + this->evaluated_symbols.pop_back(); + if (match) { + this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + return true; + } + } + { + this->evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(add->a, rhs->b) && VisitExpr(add->b, rhs->a); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + this->evaluated_symbols.pop_back(); + if (match) { + this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + return true; + } + } + return false; + } + + bool VisitExpr_(const tir::MulNode* mul, const PrimExpr& other) final { + const auto* rhs = other.as(); + if (rhs == nullptr) return false; + { + this->evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(mul->a, rhs->a) && VisitExpr(mul->b, rhs->b); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + this->evaluated_symbols.pop_back(); + if (match) { + this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + return true; + } + } + { + this->evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(mul->a, rhs->b) && VisitExpr(mul->b, rhs->a); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + this->evaluated_symbols.pop_back(); + if (match) { + this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + return true; + } + } + return false; + } + + bool VisitExpr_(const tir::CallNode* call, const PrimExpr& other) final { + const auto* rhs = other.as(); + if (rhs == nullptr) return false; + const auto* lhs_op = call->op.as(); + const auto* rhs_op = rhs->op.as(); + if (lhs_op == nullptr || rhs_op == nullptr) return false; + if (lhs_op->name != rhs_op->name) return false; + if (call->args.size() != rhs->args.size()) return false; + for (size_t i = 0; i < call->args.size(); ++i) { + if (!VisitExpr(call->args[i], rhs->args[i])) return false; + } + return true; + } + + bool VisitStmt_(const tir::ForNode* op, const Stmt& other) final { + const auto* rhs = other.as(); + loop_stack_lhs_.push_back(GetRef(op)); + loop_stack_rhs_.push_back(GetRef(rhs)); + // The body of loop must be loop or BlockRealize + if (!op->body->IsInstance() && !op->body->IsInstance()) { + return false; + } + if (!rhs->body->IsInstance() && !rhs->body->IsInstance()) { + return false; + } + // Build mapping between the loop vars + if (!DefEqual(op->loop_var, rhs->loop_var)) return false; + // Only handle the case where the loop start from 0 + if (!is_zero(op->min) || !is_zero(rhs->min)) return false; + if (op->thread_binding.defined() || rhs->thread_binding.defined()) return false; + if (op->kind != ForKind::kSerial || op->kind != rhs->kind) return false; + if (!op->annotations.empty() || !rhs->annotations.empty()) return false; + // Match the extents of loops + if (!VisitExpr(op->extent, rhs->extent)) return false; + return VisitStmt(op->body, rhs->body); + } + + bool VisitStmt_(const tir::BlockNode* op, const Stmt& other) final { + const auto* rhs = other.as(); + // Check block equality. + // All iter vars and buffer regions including the order should match. + // When checking iter vars, DefEqual is used to remap variables. 
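+    // Example (iter var names are illustrative): if the pattern block declares an iter var vi and
+    // the target block declares i0, DefEqual records the two as equivalent, so index expressions
+    // written in terms of vi on the pattern side can compare equal to those written using i0.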
+ if (!CompareArray(op->iter_vars, rhs->iter_vars, &ForMatcher::CompareIterVar)) { + return false; + } + // disallow alloc buffers inside the block + if (!op->alloc_buffers.empty() || !rhs->alloc_buffers.empty()) return false; + if (!CompareArray(op->writes, rhs->writes, &ForMatcher::CompareBufferRegion)) { + return false; + } + if (!CompareArray(op->reads, rhs->reads, &ForMatcher::CompareBufferRegion)) { + return false; + } + // The body of the block has to be BufferStore + if (!op->body->IsInstance() || !rhs->body->IsInstance()) { + return false; + } + // Handle init block + if (op->init.defined() && !rhs->init.defined()) return false; + if (!op->init.defined() && rhs->init.defined()) return false; + if (op->init.defined() && rhs->init.defined()) { + if (!VisitStmt(op->init.value(), rhs->init.value())) return false; + } + return VisitStmt(op->body, rhs->body); + } + + bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) final { + const auto* rhs = other.as(); + // Only allow trivial bindings + for (size_t i = 0; i < op->iter_values.size(); ++i) { + if (!op->iter_values[i].same_as(loop_stack_lhs_[i]->loop_var)) return false; + } + for (size_t i = 0; i < rhs->iter_values.size(); ++i) { + if (!rhs->iter_values[i].same_as(loop_stack_rhs_[i]->loop_var)) return false; + } + // Disallow predicates now + if (!is_one(op->predicate) || !is_one(rhs->predicate)) return false; + return VisitStmt(op->block, rhs->block); + } + + bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); + } + + bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs); + } + + bool CompareBuffer(const Buffer& lhs, const Buffer& rhs) { + if (lhs.same_as(rhs)) return true; + auto it = rhs_buffer_map_.find(rhs); + bool equal; + if (it != rhs_buffer_map_.end()) { + equal = (*it).second.same_as(lhs); + } else { + // Compare shape + if (lhs->shape.size() != rhs->shape.size()) return false; + for (size_t i = 0; i < lhs->shape.size(); ++i) { + if (!VisitExpr(lhs->shape[i], rhs->shape[i])) return false; + } + // Remap both buffer itself and buffer data + equal = + DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype && lhs.scope() == rhs.scope(); + if (equal) { + rhs_buffer_map_[rhs] = lhs; + } + } + return equal; + } + + bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) { + return false; + } + return CompareArray(lhs->region, rhs->region, &ForMatcher::CompareRange); + } + + template + bool CompareBufferAccess(const T* lhs, const T* rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + return CompareArray(lhs->indices, rhs->indices, &ForMatcher::VisitExpr); + } + + template + bool CompareArray(const Array& lhs, const Array& rhs, F Self::*cmp) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!(static_cast(this)->*cmp)(lhs[i], rhs[i])) return false; + } + return true; + } + + arith::Analyzer analyzer_; + std::vector loop_stack_lhs_, loop_stack_rhs_; + tir::PrimFunc pattern_; + std::unordered_set pattern_vars_; +}; + +/*! 
\brief Analyze the function and match it with a list of patterns */ +class TIRPatternMatcher { + public: + static Array Match(Array patterns, Stmt body) { + TIRPatternMatcher matcher(patterns); + matcher.OpMatternMatch(body); + if (matcher.fail_) return {}; + return matcher.match_results_; + } + + private: + explicit TIRPatternMatcher(Array patterns) : patterns_(patterns) {} + + // Find an op that matches this block + bool BlockPatternMatch(const For& top) { + for (const TIRPattern& pattern : patterns_) { + tir::PrimFunc pattern_func = pattern; + Array pattern_symbolic_vars; + int buffer_count = pattern_func->buffer_map.size(); + for (int i = buffer_count; i < static_cast(pattern_func->params.size()); i++) { + pattern_symbolic_vars.push_back(pattern_func->params[i]); + } + ForMatcher block_matcher(pattern_func, pattern_symbolic_vars); + if (block_matcher.Match(top)) { + // We have found a match + Array symbol_values; + for (int i = buffer_count; i < static_cast(pattern_func->params.size()); i++) { + symbol_values.push_back(block_matcher.evaluated_symbols.back()[pattern_func->params[i]]); + } + match_results_.push_back( + MatchResult(pattern, symbol_values, block_matcher.evaluated_buffers)); + return true; + } + } + // The block fails to match any pattern + return false; + } + + // For each block in the body, try to find its corresponding pattern one by one + void OpMatternMatch(const Stmt& body) { + Array blocks; + if (body->IsInstance()) { + // {for} + blocks = {body}; + } else if (const SeqStmtNode* seq = body.as()) { + blocks = seq->seq; + } else { + fail_ = true; + return; + } + for (const Stmt& stmt : blocks) { + const ForNode* loop = stmt.as(); + if (loop == nullptr || !BlockPatternMatch(GetRef(loop))) { + break; + } + } + if (match_results_.empty()) { + fail_ = true; + } + } + /*! \brief Indicate whether we fail to match.*/ + bool fail_ = false; + /*! \brief The patterns we match the target stmt to.*/ + Array patterns_; + /*! \brief The results of the matching process.*/ + Array match_results_; +}; + +/*! \brief helper class to partition a function into 2 parts. Return function information which we + * can use to construct the two partitioned parts.*/ +class FunctionPartitioner : public StmtExprVisitor { + public: + explicit FunctionPartitioner(int num_matched_ops) : num_matched_ops_(num_matched_ops) {} + /*! \brief alloc_buffers for the first function */ + std::unordered_set allocs1; + /*! \brief alloc_buffers for the second function */ + std::unordered_set allocs2; + /*! \brief whether the current block is in the first function */ + Map block_partition; + /*! \brief input buffers for the first function */ + std::unordered_set input1; + /*! \brief input buffers for the second function */ + std::unordered_set input2; + /*! \brief The output buffer for the first function, which is also the input buffer for the second + function */ + Buffer intermediate_buffer; + /*! \brief Indicate whether we have failed. If failed, we will not do any further analysis and + directly return the original one. 
*/ + bool fail = false; + + private: + void VisitStmt_(const BlockNode* op) final { + block_counter_++; + bool is_matching_ = block_counter_ <= num_matched_ops_; + if (block_counter_ == num_matched_ops_) { + allocs1.erase(intermediate_buffer); + } + for (const auto& read : op->reads) { + if (is_matching_) { + input1.insert(read->buffer); + } else { + input2.insert(read->buffer); + } + } + for (const auto& write : op->writes) { + if (is_matching_) { + allocs1.insert(write->buffer); + } else if (allocs1.count(write->buffer)) { + fail = true; + return; + } else { + allocs2.insert(write->buffer); + } + if (is_matching_) { + intermediate_buffer = write->buffer; + } else { + input2.insert(write->buffer); + } + } + block_partition.Set(GetRef(op), Bool(is_matching_)); + } + // The number of matched ops in the function + size_t num_matched_ops_; + size_t block_counter_ = 0; +}; + +/*! \brief remove parts according to block partition, and update the alloc_buffers for blocks */ +class BlockRemover : public StmtExprMutator { + public: + static Stmt RemoveBlockByPartition( + Stmt stmt, const Map& block_partition, + const std::unordered_set& allocs, + bool is_library_part) { + BlockRemover remover(block_partition, allocs, is_library_part); + return remover(stmt); + } + + private: + BlockRemover(const Map& block_partition, + const std::unordered_set& allocs, + bool is_library_part) + : block_partition(block_partition), allocs_(allocs), is_library_part_(is_library_part) {} + + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + ObjectPtr n = make_object(*block.operator->()); + if (op->name_hint != "root") { + ICHECK(block_partition.count(GetRef(op))); + bool block_is_library = block_partition[GetRef(op)]->value; + if (!(is_library_part_ ^ block_is_library)) { + n->body = block->body; + } else { + erased_ = true; + } + } + Array alloc_buffers; + for (const Buffer& b : block->alloc_buffers) { + if (allocs_.count(b)) { + alloc_buffers.push_back(b); + } + } + n->alloc_buffers = alloc_buffers; + return Block(n); + } + + Stmt VisitStmt_(const SeqStmtNode* op) final { + Array seq; + for (const Stmt& s : op->seq) { + Stmt new_s = VisitStmt(s); + if (erased_) { + erased_ = false; + } else { + seq.push_back(new_s); + } + } + return SeqStmt::Flatten(seq); + } + + bool erased_ = false; + Map block_partition; + std::unordered_set allocs_; + bool is_library_part_ = false; +}; + +/*! + * \brief Split the input function into two functions, one for the library kernel and one for the + * rest. + * \param func The input function. + * \param arg_partition The input arg for the functions after split. + * \param patterns The patterns to match. + * \param f_codegen The function to generate the code for the library kernel. + * \return A pair of functions, the first one is the library kernel and the second one is the + * rest. + */ +std::pair> SplitFunctions(PrimFunc func, + std::vector>* arg_partition, + Array patterns, + FCodegen f_codegen) { + // Step 1. Find the library kernel and the rest. 
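+  // For example, in a fused matmul + bias_add + relu PrimFunc (as built by the tests), the leading
+  // blocks claimed by the library codegen form the library-kernel part, and any remaining blocks
+  // (e.g. an epilogue the library cannot cover) are split out into a second function in the
+  // steps below.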
+ Stmt body = func->body.as()->block->body; + Array match_results = + TIRPatternMatcher::Match(patterns, func->body.as()->block->body); + if (match_results.empty()) { + return {func, NullOpt}; + } + Array codegen_result = f_codegen(match_results); + ICHECK(codegen_result.size() == 3); + String library_code = Downcast(codegen_result[0]); + int num_matched_ops = Downcast(codegen_result[1])->value; + Array func1_args = Downcast>(codegen_result[2]); + if (num_matched_ops == 0) { + return {func, NullOpt}; + } + FunctionPartitioner partitioner(num_matched_ops); + partitioner(body); + if (partitioner.fail) { + return {func, NullOpt}; + } + bool has_second_func = false; + for (const auto& pr : partitioner.block_partition) { + if (!pr.second->value) { + has_second_func = true; + break; + } + } + if (!has_second_func) { + // No need to split the function. + return {WithAttr(func, kLibraryKernel, library_code), NullOpt}; + } + // Step 2. Split the function into two functions. + Stmt body1 = BlockRemover::RemoveBlockByPartition(func->body, partitioner.block_partition, + partitioner.allocs1, true); + Stmt body2 = BlockRemover::RemoveBlockByPartition(func->body, partitioner.block_partition, + partitioner.allocs2, false); + // Step 3. Craft the first function. + Array new_params1; + std::vector arg_partition1; + ICHECK_LE(func1_args.size(), partitioner.input1.size()); + for (const auto& buffer : func1_args) { + ICHECK(partitioner.input1.find(buffer) != partitioner.input1.end()); + for (size_t i = 0; i < func->params.size(); i++) { + if (func->buffer_map[func->params[i]].same_as(buffer)) { + new_params1.push_back(func->params[i]); + arg_partition1.push_back(i); + break; + } + } + } + arg_partition->push_back(arg_partition1); + new_params1.push_back(Var("output", DataType::Handle())); + Map new_buffer_map1; + for (const auto& kv : func->buffer_map) { + if (partitioner.input1.count(kv.second)) { + new_buffer_map1.Set(kv.first, kv.second); + } + } + new_buffer_map1.Set(new_params1.back(), partitioner.intermediate_buffer); + PrimFunc func1 = PrimFunc(new_params1, body1, func->ret_type, new_buffer_map1, func->attrs); + func1 = WithAttr(func1, kLibraryKernel, library_code); + // Step 4. Craft the second function. 
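+  // The second function takes a fresh handle parameter ("input") bound to the intermediate buffer
+  // produced by the library part, followed by the original parameters whose buffers the remaining
+  // blocks still read.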
+ Array new_params2; + std::vector arg_partition2; + new_params2.push_back(Var("input", DataType::Handle())); + for (int i = 0; i < static_cast(func->params.size()); i++) { + Var param = func->params[i]; + if (partitioner.input2.count(func->buffer_map[param])) { + new_params2.push_back(param); + if (i != static_cast(func->params.size()) - 1) { + arg_partition2.push_back(i); + } + } + } + arg_partition->push_back(arg_partition2); + Map new_buffer_map2; + new_buffer_map2.Set(new_params2[0], partitioner.intermediate_buffer); + for (const auto& kv : func->buffer_map) { + if (partitioner.input2.count(kv.second)) { + new_buffer_map2.Set(kv.first, kv.second); + } + } + PrimFunc func2 = PrimFunc(new_params2, body2, func->ret_type, new_buffer_map2, func->attrs); + return {func1, func2}; +} +} // namespace tir + +namespace relax { +void StringReplace(std::string* subject, const std::string& search, const std::string& replace) { + for (size_t pos = 0; (pos = subject->find(search, pos)) != std::string::npos; + pos += replace.length()) { + subject->replace(pos, search.length(), replace); + } +} + +tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, String global_symbol) { + using namespace tvm::tir; + Optional library_code = pf->attrs.GetAttr(kLibraryKernel); + if (!library_code.defined()) { + return GetRef(pf); + } + std::string source = library_code.value(); + StringReplace(&source, "{global_symbol}", global_symbol); + ExternFunc ret(global_symbol); + ret = WithAttrs(std::move(ret), Map{ + {String(kCSource), String(source)}, + {String(kCSourceFmt), String(kCSourceFmtCuda)}, + }); + return ret; +} + +/*! \brief Emit 2 calls to the library kernel and the rest of the function. */ +class SplitMutator : public ExprMutator { + public: + SplitMutator(const tvm::IRModule& mod, Array patterns, FCodegen fcodegen) + : ExprMutator(mod), mod_(mod), patterns_(patterns), fcodegen_(fcodegen) {} + static IRModule Transform(const IRModule& mod, Array patterns, FCodegen fcodegen) { + SplitMutator mutator(mod, patterns, fcodegen); + for (auto& kv : mod->functions) { + if (auto* func = kv.second.as()) { + Function new_func = Downcast(mutator(GetRef(func))); + mutator.builder_->UpdateFunction(kv.first, new_func); + } + } + return mutator.builder_->GetContextIRModule(); + } + + private: + using ExprMutator::VisitExpr_; + + inline Array GetCallTIRArgs(Expr args) { + if (args.as()) { + return args.as()->fields; + } else { + return {args}; + } + } + + Expr VisitExpr_(const CallNode* op) final { + Call call = Downcast(ExprMutator::VisitExpr_(op)); + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + static const Op& call_dps_packed_ = Op::Get("relax.call_dps_packed"); + if (!call->op.same_as(call_tir_op_)) return call; + // the first argument is the function to be called + const auto* gv_ptr = call->args[0].as(); + if (gv_ptr == nullptr) return call; + GlobalVar gv = GetRef(gv_ptr); + // retrieve the function from the module and split it + tir::PrimFunc func = Downcast(mod_->Lookup(gv)); + std::vector> arg_partition; + // split the function into two functions, one for the library kernel and one for the rest. 
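+    // Schematically (function names other than "unfused_epilogue" are illustrative), a call such as
+    //   call_tir(fused_matmul_add, (A, B, bias))
+    // is rewritten into
+    //   lv = call_dps_packed(extern_gemm, (A, B))        // library kernel
+    //   call_tir(unfused_epilogue, (lv, bias))           // remaining TIR blocks
+    // where lv holds the intermediate buffer written by the library kernel.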
+ std::pair> split_funcs = + tir::SplitFunctions(func, &arg_partition, patterns_, fcodegen_); + if (!split_funcs.second.defined()) { + // no need to split, the function itself a library kernel + tvm::BaseFunc lib_func = CodegenWithLibrary(split_funcs.first.get(), gv->name_hint); + if (lib_func->IsInstance()) return GetRef(op); + // Update the function in the module with the library kernel + ICHECK(lib_func->IsInstance()); + builder_->UpdateFunction(gv, lib_func); + // emit the call to the library kernel + ObjectPtr new_call = make_object(*call.operator->()); + new_call->op = this->call_dps_packed_; + new_call->args = {lib_func, call->args[1]}; + return Call(new_call); + } + tir::PrimFunc func1 = tir::RenewDefs(split_funcs.first); + tir::PrimFunc func2 = tir::RenewDefs(split_funcs.second.value()); + ICHECK(arg_partition.size() == 2); + // emit the first call to the library kernel + Array args1; + for (int p : arg_partition[0]) { + args1.push_back(GetCallTIRArgs(call->args[1])[p]); + } + // replace the function in the module with the library kernel + tvm::BaseFunc lib_func = CodegenWithLibrary(func1.get(), gv->name_hint); + if (lib_func->IsInstance()) return GetRef(op); + ICHECK(lib_func->IsInstance()); + builder_->UpdateFunction(gv, lib_func); + tir::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back()); + DataType dtype = intermediate_buffer->dtype; + Call call1(call_dps_packed_, {lib_func, Tuple(args1)}, call->attrs, + {TensorStructInfo(ShapeExpr(intermediate_buffer->shape), dtype)}); + Var call_var1 = builder_->Emit(call1); + // emit the second call to the rest of the function + Array args2; + args2.push_back(call_var1); + for (int p : arg_partition[1]) { + args2.push_back(GetCallTIRArgs(call->args[1])[p]); + } + GlobalVar gv2 = builder_->AddFunction(func2, "unfused_epilogue"); + Call call2(call_tir_op_, {gv2, Tuple(args2)}, call->attrs, call->sinfo_args); + builder_->UpdateFunction(gv, WithoutAttr(func, "global_symbol")); + return call2; + } + + const Op& call_dps_packed_ = Op::Get("relax.call_dps_packed"); + tvm::IRModule mod_; + Array patterns_; + FCodegen fcodegen_; +}; + +namespace transform { +Pass SplitCallTIRByPattern(Array patterns, FCodegen fcodegen) { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { return SplitMutator::Transform(m, patterns, fcodegen); }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"SplitCallTIRByPattern", // + /*required=*/{}); +} +TVM_REGISTER_GLOBAL("relax.transform.SplitCallTIRByPattern").set_body_typed(SplitCallTIRByPattern); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_codegen_tir_cutlass.py b/tests/python/relax/test_codegen_tir_cutlass.py new file mode 100644 index 000000000000..7e3642b1b6ce --- /dev/null +++ b/tests/python/relax/test_codegen_tir_cutlass.py @@ -0,0 +1,755 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +import tempfile + +from tvm import relax, runtime +import tvm +import tvm.testing +from tvm import relax +import scipy +from scipy.special import erf +import numpy as np +from tvm.target import Target +from tvm.relax.vm_build import build as relax_build +from tvm.script.ir_builder import relax as R +from tvm.script.ir_builder import ir as I +from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import IRBuilder + +from tvm.relax.backend_tir import get_tir_pattern +from tvm.relax.backend_tir.contrib.cutlass import cutlass_fcodegen, compile_options + +A_TYPE = "float16" +B_TYPE = "float16" +C_TYPE = "float16" + +target = Target("cuda") + + +def f_run(rt_mod: runtime.Module, device: runtime.ndarray.Device, *input): + vm = relax.vm.VirtualMachine(rt_mod=rt_mod, device=device) + return vm["main"](*input) + + +def build(mod): + mod = relax.transform.LegalizeOps()(mod) + mod = relax.transform.AnnotateTIROpPattern()(mod) + mod = relax.transform.FuseOps()(mod) + mod = relax.transform.FuseTIR()(mod) + mod = relax.transform.SplitCallTIRByPattern(get_tir_pattern(), cutlass_fcodegen())(mod) + mod = relax.transform.DeadCodeElimination()(mod) + print(mod.script()) + f = tempfile.NamedTemporaryFile(suffix=".so", delete=True) + executable = relax_build(mod, target) + + executable.mod.export_library(f.name, **compile_options(target)) + rt_mod = runtime.load_module(f.name) + f.close() + return rt_mod + + +def constructGEMM(M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + R.output(C) + (C,) = df.output_vars + R.func_ret_value(C) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_dense(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + result = f_run(executable, dev, A_tvm, B_tvm) + np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) + + +def constructGEMM_bias(M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, N), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +def constructGEMM_bias2(M, N, K): + with IRBuilder() as 
ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((N,), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_dense_bias(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM_bias(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(1, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +@tvm.testing.requires_cutlass +def test_cutlass_dense_bias2(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM_bias2(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +def constructGEMM_bias_relu(M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, N), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + E = R.emit(R.nn.relu(D)) + R.output(E) + (E,) = df.output_vars + R.func_ret_value(E) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_dense_bias_relu(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM_bias_relu(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(1, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), np.maximum(A @ B + bias, 0), rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + R.output(C) + (C,) = df.output_vars + R.func_ret_value(C) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense(): + 
b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + result = f_run(executable, dev, A_tvm, B_tvm) + np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM2(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((batch, K, N), B_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + R.output(C) + (C,) = df.output_vars + R.func_ret_value(C) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense2(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM2(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(b, k, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + result = f_run(executable, dev, A_tvm, B_tvm) + np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM_bias(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, N), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense_bias(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM_bias(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(1, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM_bias2(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((N,), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense_bias2(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM_bias2(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = 
np.random.randn(k, n).astype("float16") + bias = np.random.randn(n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM_bias2_gelu(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((N,), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + E = R.emit(R.nn.gelu(D)) + R.output(E) + (E,) = df.output_vars + R.func_ret_value(E) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense_bias2_gelu(): + b, m, n, k = 2, 128, 64, 256 + executable = build(constructBatchGEMM_bias2_gelu(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + C = A @ B + bias + O = 0.5 * C * (1 + erf(C / np.sqrt(2))) + np.testing.assert_allclose(result.numpy(), O, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM_bias2_mul(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((N,), A_TYPE) + ) # pylint: disable=invalid-name + residual = R.arg("residual", relax.TensorStructInfo((batch, M, N), A_TYPE)) + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + E = R.emit(R.multiply(D, residual)) + R.output(E) + (E,) = df.output_vars + R.func_ret_value(E) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense_bias2_mul(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM_bias2_mul(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(n).astype("float16") + residual = np.random.randn(b, m, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + residual_tvm = tvm.nd.array(residual, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, residual_tvm) + np.testing.assert_allclose(result.numpy(), ((A @ B) + bias) * residual, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM2_bias(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((batch, K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = 
R.arg( + "bias", relax.TensorStructInfo((1, N), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense2_bias(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM2_bias(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(b, k, n).astype("float16") + bias = np.random.randn(1, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +def constructConv2D(N, C, H, W, KH, KW, O, strides, padding, dilation): + from tvm.script.ir_builder import IRBuilder + from tvm.script.ir_builder import ir as I + from tvm.script.ir_builder import relax as R + from tvm.script.ir_builder import tir as T + + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + x = R.arg( + "x", relax.TensorStructInfo((N, H, W, C), A_TYPE) + ) # pylint: disable=invalid-name + w = R.arg( + "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit( + R.nn.conv2d( + x, + w, + strides=strides, + padding=padding, + dilation=dilation, + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype=C_TYPE, + ) + ) + R.output(C) + (C,) = df.output_vars + R.func_ret_value(C) + mod = ib.get() + return mod + + +@tvm.testing.requires_cutlass +def test_cutlass_conv2d(): + import torch + + n, c, h, w = 1, 3, 224, 224 + kh, kw, o = 3, 3, 64 + counter = 0 + for strides in [(1, 1), (2, 2)]: + for padding in [(0, 0), (3, 3)]: + for dilation in [(1, 1), (4, 4)]: + executable = build( + constructConv2D(n, c, h, w, kh, kw, o, strides, padding, dilation) + ) + dev = tvm.cuda() + np.random.seed(0) + A = np.random.randn(n, h, w, c).astype("float16") + B = np.random.randn(o, kh, kw, c).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + result = f_run(executable, dev, A_tvm, B_tvm) + A_torch = torch.from_numpy(np.transpose(A, (0, 3, 1, 2))).to( + torch.float32 + ) # .cuda() + B_torch = torch.from_numpy(np.transpose(B, (0, 3, 1, 2))).to( + torch.float32 + ) # .cuda() + C_torch = torch.nn.functional.conv2d( + A_torch, B_torch, stride=strides, padding=padding, dilation=dilation + ) + np.testing.assert_allclose( + np.transpose(result.numpy(), (0, 3, 1, 2)), + C_torch.cpu().numpy(), + rtol=5e-2, + atol=5e-2, + ) + counter += 1 + + +def constructConv2D_bias(N, C, H, W, KH, KW, O, strides, padding, dilation): + from tvm.script.ir_builder import IRBuilder + from tvm.script.ir_builder import ir as I + from tvm.script.ir_builder import relax as R + from tvm.script.ir_builder import tir as T + + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + x = R.arg( + "x", relax.TensorStructInfo((N, H, W, C), A_TYPE) + ) # pylint: disable=invalid-name + w = R.arg( + "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE) + ) # pylint: disable=invalid-name + with 
R.dataflow() as df: + C = R.emit( + R.nn.conv2d( + x, + w, + strides=strides, + padding=padding, + dilation=dilation, + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype=C_TYPE, + ) + ) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + mod = ib.get() + return mod + + +@tvm.testing.requires_cutlass +def test_cutlass_conv2d_bias(): + import torch + + c, h, w = 3, 224, 224 + kh, kw, o = 3, 3, 64 + counter = 0 + for n in [1, 2]: + for strides in [(1, 1), (2, 2)]: + for padding in [(0, 0), (3, 3)]: + for dilation in [(1, 1), (4, 4)]: + filename = "/tmp/" + "test_transform_cutlass_codegen" + str(counter) + ".so" + executable = build( + constructConv2D_bias(n, c, h, w, kh, kw, o, strides, padding, dilation), + ) + dev = tvm.cuda() + np.random.seed(0) + A = np.random.randn(n, h, w, c).astype("float16") + B = np.random.randn(o, kh, kw, c).astype("float16") + bias = np.random.randn(o).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias.reshape(1, 1, 1, o), dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + A_torch = torch.from_numpy(np.transpose(A, (0, 3, 1, 2))).to( + torch.float32 + ) # .cuda() + B_torch = torch.from_numpy(np.transpose(B, (0, 3, 1, 2))).to( + torch.float32 + ) # .cuda() + bias_torch = torch.from_numpy(bias).to(torch.float32) # .cuda() + C_torch = torch.nn.functional.conv2d( + A_torch, + B_torch, + bias=bias_torch, + stride=strides, + padding=padding, + dilation=dilation, + ) + np.testing.assert_allclose( + np.transpose(result.numpy(), (0, 3, 1, 2)), + C_torch.cpu().numpy(), + rtol=5e-2, + atol=5e-2, + ) + counter += 1 + + +def constructConv2D_bias_add(N, C, H, W, KH, KW, O, OH, OW, strides, padding, dilation): + from tvm.script.ir_builder import IRBuilder + from tvm.script.ir_builder import ir as I + from tvm.script.ir_builder import relax as R + from tvm.script.ir_builder import tir as T + + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + x = R.arg( + "x", relax.TensorStructInfo((N, H, W, C), A_TYPE) + ) # pylint: disable=invalid-name + w = R.arg( + "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE) + ) # pylint: disable=invalid-name + res = R.arg( + "res", relax.TensorStructInfo((N, OH, OW, O), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit( + R.nn.conv2d( + x, + w, + strides=strides, + padding=padding, + dilation=dilation, + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype=C_TYPE, + ) + ) + D = R.emit(R.add(C, bias)) + E = R.emit(R.add(D, res)) + R.output(E) + (E,) = df.output_vars + R.func_ret_value(E) + mod = ib.get() + return mod + + +@tvm.testing.requires_cutlass +def test_cutlass_conv2d_bias_add(): + import torch + + n, c, h, w = 2, 3, 224, 224 + kh, kw, o = 3, 3, 64 + counter = 0 + for strides in [(1, 1), (2, 2)]: + for padding in [(0, 0), (3, 3)]: + for dilation in [(1, 1), (4, 4)]: + oh = (h + 2 * padding[0] - dilation[0] * (kh - 1) - 1) // strides[0] + 1 + ow = (w + 2 * padding[1] - dilation[1] * (kw - 1) - 1) // strides[1] + 1 + executable = build( + constructConv2D_bias_add( + n, c, h, w, kh, kw, o, oh, ow, strides, padding, dilation + ) + ) + dev = tvm.cuda() + np.random.seed(0) + A = np.random.randn(n, h, w, c).astype("float16") + B = np.random.randn(o, 
kh, kw, c).astype("float16") + bias = np.random.randn(o).astype("float16") + res = np.random.randn(n, oh, ow, o).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias.reshape(1, 1, 1, o), dev) + res_tvm = tvm.nd.array(res, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, res_tvm) + A_torch = torch.from_numpy(np.transpose(A, (0, 3, 1, 2))).to( + torch.float32 + ) # .cuda() + B_torch = torch.from_numpy(np.transpose(B, (0, 3, 1, 2))).to( + torch.float32 + ) # .cuda() + bias_torch = torch.from_numpy(bias).to(torch.float32) # .cuda() + res_torch = torch.from_numpy(np.transpose(res, (0, 3, 1, 2))).to( + torch.float32 + ) # .cuda() + C_torch = torch.nn.functional.conv2d( + A_torch, + B_torch, + bias=bias_torch, + stride=strides, + padding=padding, + dilation=dilation, + ) + D_torch = C_torch + res_torch + np.testing.assert_allclose( + np.transpose(result.numpy(), (0, 3, 1, 2)), + D_torch.cpu().numpy(), + rtol=5e-2, + atol=5e-2, + ) + counter += 1 + + +if __name__ == "__main__": + tvm.testing.main() From ab49c45ac1bf2b985122786a7803b5656247b397 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 23 Mar 2023 20:19:46 -0700 Subject: [PATCH 2/2] . --- .../python/relax/test_codegen_tir_cutlass.py | 106 +++++------------- 1 file changed, 30 insertions(+), 76 deletions(-) diff --git a/tests/python/relax/test_codegen_tir_cutlass.py b/tests/python/relax/test_codegen_tir_cutlass.py index 7e3642b1b6ce..9c960ed355d3 100644 --- a/tests/python/relax/test_codegen_tir_cutlass.py +++ b/tests/python/relax/test_codegen_tir_cutlass.py @@ -64,6 +64,16 @@ def build(mod): return rt_mod +def build_and_run_reference(mod, inputs_np): + mod = relax.transform.LegalizeOps()(mod) + dev = tvm.device("llvm", 0) + ex = relax.build(mod, "llvm") + vm = relax.VirtualMachine(ex, dev) + f = vm["main"] + inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + return f(*inputs).numpy() + + def constructGEMM(M, N, K): with IRBuilder() as ib: # pylint: disable=invalid-name with I.ir_module() as frame: @@ -523,17 +533,13 @@ def constructConv2D(N, C, H, W, KH, KW, O, strides, padding, dilation): @tvm.testing.requires_cutlass def test_cutlass_conv2d(): - import torch - n, c, h, w = 1, 3, 224, 224 kh, kw, o = 3, 3, 64 - counter = 0 for strides in [(1, 1), (2, 2)]: for padding in [(0, 0), (3, 3)]: for dilation in [(1, 1), (4, 4)]: - executable = build( - constructConv2D(n, c, h, w, kh, kw, o, strides, padding, dilation) - ) + mod = constructConv2D(n, c, h, w, kh, kw, o, strides, padding, dilation) + executable = build(mod) dev = tvm.cuda() np.random.seed(0) A = np.random.randn(n, h, w, c).astype("float16") @@ -541,22 +547,13 @@ def test_cutlass_conv2d(): A_tvm = tvm.nd.array(A, dev) B_tvm = tvm.nd.array(B, dev) result = f_run(executable, dev, A_tvm, B_tvm) - A_torch = torch.from_numpy(np.transpose(A, (0, 3, 1, 2))).to( - torch.float32 - ) # .cuda() - B_torch = torch.from_numpy(np.transpose(B, (0, 3, 1, 2))).to( - torch.float32 - ) # .cuda() - C_torch = torch.nn.functional.conv2d( - A_torch, B_torch, stride=strides, padding=padding, dilation=dilation - ) + result_ref = build_and_run_reference(mod, [A, B]) np.testing.assert_allclose( - np.transpose(result.numpy(), (0, 3, 1, 2)), - C_torch.cpu().numpy(), + result.numpy(), + result_ref, rtol=5e-2, atol=5e-2, ) - counter += 1 def constructConv2D_bias(N, C, H, W, KH, KW, O, strides, padding, dilation): @@ -603,50 +600,30 @@ def constructConv2D_bias(N, C, H, W, KH, KW, O, strides, padding, dilation): 
@tvm.testing.requires_cutlass def test_cutlass_conv2d_bias(): - import torch - c, h, w = 3, 224, 224 kh, kw, o = 3, 3, 64 - counter = 0 for n in [1, 2]: for strides in [(1, 1), (2, 2)]: for padding in [(0, 0), (3, 3)]: for dilation in [(1, 1), (4, 4)]: - filename = "/tmp/" + "test_transform_cutlass_codegen" + str(counter) + ".so" - executable = build( - constructConv2D_bias(n, c, h, w, kh, kw, o, strides, padding, dilation), - ) + mod = constructConv2D_bias(n, c, h, w, kh, kw, o, strides, padding, dilation) + executable = build(mod) dev = tvm.cuda() np.random.seed(0) A = np.random.randn(n, h, w, c).astype("float16") B = np.random.randn(o, kh, kw, c).astype("float16") - bias = np.random.randn(o).astype("float16") + bias = np.random.randn(1, 1, 1, o).astype("float16") A_tvm = tvm.nd.array(A, dev) B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias.reshape(1, 1, 1, o), dev) + bias_tvm = tvm.nd.array(bias, dev) result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) - A_torch = torch.from_numpy(np.transpose(A, (0, 3, 1, 2))).to( - torch.float32 - ) # .cuda() - B_torch = torch.from_numpy(np.transpose(B, (0, 3, 1, 2))).to( - torch.float32 - ) # .cuda() - bias_torch = torch.from_numpy(bias).to(torch.float32) # .cuda() - C_torch = torch.nn.functional.conv2d( - A_torch, - B_torch, - bias=bias_torch, - stride=strides, - padding=padding, - dilation=dilation, - ) + result_ref = build_and_run_reference(mod, [A, B, bias]) np.testing.assert_allclose( - np.transpose(result.numpy(), (0, 3, 1, 2)), - C_torch.cpu().numpy(), + result.numpy(), + result_ref, rtol=5e-2, atol=5e-2, ) - counter += 1 def constructConv2D_bias_add(N, C, H, W, KH, KW, O, OH, OW, strides, padding, dilation): @@ -697,58 +674,35 @@ def constructConv2D_bias_add(N, C, H, W, KH, KW, O, OH, OW, strides, padding, di @tvm.testing.requires_cutlass def test_cutlass_conv2d_bias_add(): - import torch - n, c, h, w = 2, 3, 224, 224 kh, kw, o = 3, 3, 64 - counter = 0 for strides in [(1, 1), (2, 2)]: for padding in [(0, 0), (3, 3)]: for dilation in [(1, 1), (4, 4)]: oh = (h + 2 * padding[0] - dilation[0] * (kh - 1) - 1) // strides[0] + 1 ow = (w + 2 * padding[1] - dilation[1] * (kw - 1) - 1) // strides[1] + 1 - executable = build( - constructConv2D_bias_add( - n, c, h, w, kh, kw, o, oh, ow, strides, padding, dilation - ) + mod = constructConv2D_bias_add( + n, c, h, w, kh, kw, o, oh, ow, strides, padding, dilation ) + executable = build(mod) dev = tvm.cuda() np.random.seed(0) A = np.random.randn(n, h, w, c).astype("float16") B = np.random.randn(o, kh, kw, c).astype("float16") - bias = np.random.randn(o).astype("float16") + bias = np.random.randn(1, 1, 1, o).astype("float16") res = np.random.randn(n, oh, ow, o).astype("float16") A_tvm = tvm.nd.array(A, dev) B_tvm = tvm.nd.array(B, dev) - bias_tvm = tvm.nd.array(bias.reshape(1, 1, 1, o), dev) + bias_tvm = tvm.nd.array(bias, dev) res_tvm = tvm.nd.array(res, dev) result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, res_tvm) - A_torch = torch.from_numpy(np.transpose(A, (0, 3, 1, 2))).to( - torch.float32 - ) # .cuda() - B_torch = torch.from_numpy(np.transpose(B, (0, 3, 1, 2))).to( - torch.float32 - ) # .cuda() - bias_torch = torch.from_numpy(bias).to(torch.float32) # .cuda() - res_torch = torch.from_numpy(np.transpose(res, (0, 3, 1, 2))).to( - torch.float32 - ) # .cuda() - C_torch = torch.nn.functional.conv2d( - A_torch, - B_torch, - bias=bias_torch, - stride=strides, - padding=padding, - dilation=dilation, - ) - D_torch = C_torch + res_torch + result_ref = build_and_run_reference(mod, [A, B, 
bias, res]) np.testing.assert_allclose( - np.transpose(result.numpy(), (0, 3, 1, 2)), - D_torch.cpu().numpy(), + result.numpy(), + result_ref, rtol=5e-2, atol=5e-2, ) - counter += 1 if __name__ == "__main__":