diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index 96d5922e84d9..1502c4f8bc80 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -50,8 +50,8 @@ if(USE_CUDA) if(USE_CUBLAS) message(STATUS "Build with cuBLAS support") - tvm_file_glob(GLOB CUBLAS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cublas/*.cc) - list(APPEND COMPILER_SRCS ${CUBLAS_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB CUBLAS_CONTRIB_SRC src/relay/backend/contrib/cublas/*.cc src/relax/backend/contrib/cublas/*.cc) + list(APPEND COMPILER_SRCS ${CUBLAS_CONTRIB_SRC}) tvm_file_glob(GLOB CONTRIB_CUBLAS_SRCS src/runtime/contrib/cublas/*.cc) list(APPEND RUNTIME_SRCS ${CONTRIB_CUBLAS_SRCS}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUBLAS_LIBRARY}) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 93d1331ac443..43494991a04c 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -560,22 +560,9 @@ def _extract_relax_function_signature(f): def _extract_arg_idx(pattern_name, f): - pattern_entry = relax.backend.get_pattern(pattern_name) - if pattern_entry is None: - raise ValueError(f"Unsupported op_type {pattern_name}") - var2val = relax.analysis.get_var2val(f) - matched_expr = pattern_entry.pattern.extract_matched_expr(f.body.body, var2val) - - func_args = list(f.params) - - arg_idx = {} - for name, annotation_pattern in pattern_entry.annotation_patterns.items(): - arg_expr = matched_expr[annotation_pattern] - if arg_expr not in func_args: - continue - arg_idx[name] = func_args.index(arg_expr) - - return arg_idx + extract_func = tvm.get_global_func("relax.contrib.extract_arg_idx") + arg_indices = extract_func(pattern_name, f) + return {k: int(v) for k, v in arg_indices.items()} def is_shape_valid_for_cutlass_matmul( diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py new file mode 100644 index 000000000000..627c9369935b --- /dev/null +++ b/python/tvm/relax/backend/contrib/cublas.py @@ -0,0 +1,154 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Pattern table for cuBLAS backend""" +import operator +from functools import reduce + +import tvm +from tvm.relax import transform +from tvm.relax.transform import PatternCheckContext + +from ..pattern_registry import get_patterns_with_prefix, register_patterns +from ..patterns import make_matmul_pattern + + +def _is_supported_dtype(lhs_dtype, rhs_dtype): + """Check if dtypes in the given workload are supported by cuBLAS BYOC.""" + return (lhs_dtype == "float16" and rhs_dtype == "float16") or ( + lhs_dtype == "float32" and rhs_dtype == "float32" + ) + + +def _check_matmul(context: PatternCheckContext) -> bool: + lhs = context.annotated_expr["lhs"] + rhs = context.annotated_expr["rhs"] + + lhs_dtype = lhs.struct_info.dtype + rhs_dtype = rhs.struct_info.dtype + if not _is_supported_dtype(lhs_dtype, rhs_dtype): + return False + + lhs_shape = lhs.struct_info.shape.values + rhs_shape = rhs.struct_info.shape.values + + if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)): + # Reduction axis must be constant + return False + + lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1) + rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1) + + # cuBLASLt does not seem to support batched GEMM with one of matrices having + # one batch (with batch_stride 0). So for batched GEMM, the two batch counts + # must be equal. + return ( + (lhs_batches == 1 and rhs_batches == 1) + or isinstance(lhs_batches, tvm.tir.Var) + or isinstance(rhs_batches, tvm.tir.Var) + or (int(lhs_batches) == int(rhs_batches)) + ) + + +register_patterns( + [ + ( + "cublas.matmul", + *make_matmul_pattern( + with_bias=False, + ), + _check_matmul, + ), + ( + "cublas.matmul_bias", + *make_matmul_pattern( + with_bias=True, + ), + _check_matmul, + ), + ( + "cublas.matmul_bias_relu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.relu", + ), + _check_matmul, + ), + ( + "cublas.matmul_bias_gelu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.gelu", + ), + _check_matmul, + ), + ( + "cublas.matmul_transposed", + *make_matmul_pattern( + with_bias=False, + transposed_rhs=True, + ), + _check_matmul, + ), + ( + "cublas.matmul_transposed_bias", + *make_matmul_pattern( + with_bias=True, + transposed_rhs=True, + ), + _check_matmul, + ), + ( + "cublas.matmul_transposed_bias_relu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.relu", + transposed_rhs=True, + ), + _check_matmul, + ), + ( + "cublas.matmul_transposed_bias_gelu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.gelu", + transposed_rhs=True, + ), + _check_matmul, + ), + ] +) + + +def partition_for_cublas(mod): + """ + Partition the input module into cuBLAS-supported subgraphs. + + Parameters + ---------- + mod: tvm.IRModule + The IRModule to be partitioned. + + Returns + ------- + mod: tvm.IRModule + The resulting IRModule, containing partitioned subgraphs to be + offloaded to the cuBLAS backend. 
+ """ + + patterns = get_patterns_with_prefix("cublas") + return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod) diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index c03c913d63cd..856cd4d7871f 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -17,11 +17,10 @@ """Pattern table for CUTLASS backend""" -from typing import Mapping, Optional, Sequence, Tuple +from typing import Mapping, Sequence -import tvm from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul -from tvm.relax import DataflowVar, ShapeExpr, Var, transform +from tvm.relax import DataflowVar, Var, transform from tvm.relax.transform import PatternCheckContext from ..pattern_registry import get_patterns_with_prefix, register_patterns @@ -33,16 +32,6 @@ ) -def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]: - result = [] - for dim in shape.values: - if isinstance(dim, tvm.tir.expr.IntImm): - result.append(int(dim)) - else: - return None - return result - - def _is_supported_dtype(lhs_dtype, rhs_dtype): """Check if dtypes in the given workload are supported by CUTLASS.""" return ( diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py index a6e3a9425147..4256ebc3be89 100644 --- a/python/tvm/relax/testing/__init__.py +++ b/python/tvm/relax/testing/__init__.py @@ -20,3 +20,4 @@ from .nn import * from .relay_translator import * from .ast_printer import dump_ast +from .matmul import * diff --git a/python/tvm/relax/testing/matmul.py b/python/tvm/relax/testing/matmul.py new file mode 100644 index 000000000000..bac6fc6c9ae8 --- /dev/null +++ b/python/tvm/relax/testing/matmul.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Utilities to construct matmul workloads.""" +import tvm +from tvm.script import relax as R +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import relax as relax_builder + + +def get_relax_matmul_module( + x_shape, + y_shape, + dtype, + transposed_y=False, + with_bias=False, + activation=None, + residual_bin_op=None, + residual_activation=None, +): + """Create a matmul op followd by epilogue operations.""" + if transposed_y: + n = y_shape[-2] + else: + n = y_shape[-1] + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + x = R.arg("x", R.Tensor(x_shape, dtype)) + y = R.arg("y", R.Tensor(y_shape, dtype)) + if with_bias: + bias = R.arg("bias", R.Tensor((n,), dtype)) + + with R.dataflow() as frame: + if transposed_y: + axes = list(range(len(y_shape) - 2)) + [-1, -2] + y = R.emit(R.permute_dims(y, axes=axes)) + result = R.emit(R.matmul(x, y, out_dtype=dtype)) + if with_bias: + result = R.emit(result + bias) + if activation is not None: + result = R.emit(activation(result)) + if residual_bin_op is not None: + result = R.emit(residual_bin_op(result, x)) + if residual_activation is not None: + result = R.emit(residual_activation(result)) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc new file mode 100644 index 000000000000..e573d9a12385 --- /dev/null +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/cublas/codegen.cc + * \brief Implementation of the CUBLAS JSON serializer. 
+ */
+#include <tvm/ir/module.h>
+
+#include <memory>
+
+#include "../codegen_json/codegen_json.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relax {
+namespace contrib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+using JSONSerializer = backend::contrib::JSONSerializer;
+using backend::contrib::NodeEntries;
+
+class CublasJSONSerializer : public JSONSerializer {
+ public:
+  CublasJSONSerializer(Map<Constant, String> constant_names, Map<Var, Expr> bindings)
+      : JSONSerializer(constant_names), bindings_(bindings) {}
+
+  using JSONSerializer::VisitExpr_;
+
+  NodeEntries VisitExpr_(const CallNode* call_node) final {
+    const auto* fn_var = call_node->op.as<VarNode>();
+    ICHECK(fn_var);
+    const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
+    ICHECK(fn.defined()) << "Expects the callee to be a function.";
+
+    auto composite_opt = fn->GetAttr<String>(attr::kComposite);
+    ICHECK(composite_opt.defined()) << "Only composite functions are supported.";
+
+    std::string composite_name = composite_opt.value();
+
+    NodeEntries inputs_tmp;
+    for (const auto& arg : call_node->args) {
+      auto res = VisitExpr(arg);
+      inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end());
+    }
+
+    ICHECK(inputs_tmp.size() <= 3);
+    NodeEntries inputs(inputs_tmp.size());
+
+    auto arg_idx = backend::ExtractArgIdx(composite_name, fn);
+    inputs[0] = inputs_tmp[arg_idx["lhs"]->value];
+    inputs[1] = inputs_tmp[arg_idx["rhs"]->value];
+    if (inputs_tmp.size() == 3) {
+      inputs[2] = inputs_tmp[arg_idx["bias"]->value];
+    }
+
+    auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
+                                                "kernel",       /* op_type_ */
+                                                inputs, 1 /* num_outputs_ */);
+
+    const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul");
+    SetCallNodeAttribute(node, root_call);
+    return AddNode(node, GetRef<Expr>(call_node));
+  }
+
+ private:
+  /*! \brief The bindings to look up composite functions. */
+  Map<Var, Expr> bindings_;
+};
+
+Array<runtime::Module> CublasCompiler(Array<Function> functions, Map<String, ObjectRef> /*unused*/,
+                                      Map<Constant, String> constant_names) {
+  Array<runtime::Module> compiled_functions;
+
+  for (const auto& func : functions) {
+    CublasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
+    serializer.serialize(func);
+    auto graph_json = serializer.GetJSON();
+    auto constant_names = serializer.GetConstantNames();
+    const auto* pf = runtime::Registry::Get("runtime.CublasJSONRuntimeCreate");
+    ICHECK(pf != nullptr) << "Cannot find CUBLAS runtime module create function.";
+    auto func_name = GetExtSymbol(func);
+    compiled_functions.push_back((*pf)(func_name, graph_json, constant_names));
+  }
+
+  return compiled_functions;
+}
+
+TVM_REGISTER_GLOBAL("relax.ext.cublas").set_body_typed(CublasCompiler);
+
+}  // namespace contrib
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc
new file mode 100644
index 000000000000..565b7769f0ad
--- /dev/null
+++ b/src/relax/backend/contrib/utils.cc
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "utils.h"
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/expr.h>
+
+#include <optional>
+
+#include "../pattern_registry.h"
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
+Map<String, IntImm> ExtractArgIdx(String pattern_name, Function f) {
+  Map<String, IntImm> arg_idx;
+  auto pattern = backend::GetPattern(pattern_name);
+  ICHECK(pattern) << "Unsupported op_type " << pattern_name;
+
+  auto bindings = AnalyzeVar2Value(f);
+  auto inner_body = Downcast<SeqExpr>(f->body)->body;
+  auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern, inner_body, bindings);
+  ICHECK(matched_expr);
+
+  auto find_index = [](const Array<Var>& params, Var v) -> std::optional<size_t> {
+    for (size_t i = 0; i < params.size(); ++i) {
+      if (params[i] == v) {
+        return i;
+      }
+    }
+    return std::nullopt;
+  };
+
+  for (const auto& [name, pat] : pattern.value()->annotation_patterns) {
+    auto exp = matched_expr.value()[pat];
+    if (auto arg_var = exp.as<VarNode>()) {
+      if (auto idx = find_index(f->params, GetRef<Var>(arg_var))) {
+        arg_idx.Set(name, IntImm(DataType::Int(64), *idx));
+      }
+    }
+  }
+
+  return arg_idx;
+}
+
+TVM_REGISTER_GLOBAL("relax.contrib.extract_arg_idx").set_body_typed(ExtractArgIdx);
+
+}  // namespace backend
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h
index 4190ad66b6df..ee1240aaed2e 100644
--- a/src/relax/backend/contrib/utils.h
+++ b/src/relax/backend/contrib/utils.h
@@ -120,6 +120,19 @@ inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) {
   return nullptr;
 }
 
+/*!
+ * \brief Extract the indices of pattern-annotated arguments in the partitioned
+ * function's parameter list.
+ * Each composite function pattern can register a mapping between variable names and
+ * the corresponding sub-patterns. Given a variable name from that mapping, this
+ * function reports the index at which the matched argument appears in the
+ * partitioned function's parameter list.
+ * \param pattern_name The name of the composite function pattern.
+ * \param f The function partitioned according to the function pattern.
+ * \return A mapping from variable pattern names to their parameter positions.
+ */
+Map<String, IntImm> ExtractArgIdx(String pattern_name, Function f);
+
 }  // namespace backend
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h
index 4724b14bffa1..af073da9ba1a 100644
--- a/src/runtime/contrib/cblas/gemm_common.h
+++ b/src/runtime/contrib/cblas/gemm_common.h
@@ -35,7 +35,7 @@ namespace tvm {
 namespace contrib {
 using namespace runtime;
 
-inline int ColumnStride(DLTensor* tensor) {
+inline int ColumnStride(const DLTensor* tensor) {
   // If the tensor itself is transposed then it will have strides
   // backward from what we expect. Regardless, the max of the strides
   // (the other stride is 1) is the column stride.
@@ -46,7 +46,7 @@ inline int ColumnStride(DLTensor* tensor) {
   }
 }
 
-inline int ElementStride(DLTensor* tensor) {
+inline int ElementStride(const DLTensor* tensor) {
   if (tensor->strides) {
     return std::min(tensor->strides[0], tensor->strides[1]);
   } else {
@@ -55,13 +55,17 @@
 }
 
 // Reversed strides indicates an in-place transpose operation.
-inline bool IsInPlaceTransposed(DLTensor* tensor) {
+inline bool IsInPlaceTransposed(const DLTensor* tensor) {
   return tensor->strides && (tensor->strides[1] > tensor->strides[0]);
 }
 
-inline int RowCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 1 : 0]; }
+inline int RowCount(const DLTensor* tensor, bool trans, int batch_offset = 0) {
+  return tensor->shape[batch_offset + (trans ? 1 : 0)];
+}
 
-inline int ColumnCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 0 : 1]; }
+inline int ColumnCount(const DLTensor* tensor, bool trans, int batch_offset = 0) {
+  return tensor->shape[batch_offset + (trans ? 0 : 1)];
+}
 
 // Call a column major blas.  Note that data is stored in tvm as row
 // major, so this we switch the arguments.
@@ -159,7 +163,7 @@ inline int ColumnStride3D(DLTensor* tensor) {
     return tensor->shape[2];
   }
 }
-inline int ElementStride3D(DLTensor* tensor) {
+inline int ElementStride3D(const DLTensor* tensor) {
   if (tensor->strides) {
     return std::min(tensor->strides[1], tensor->strides[2]);
   } else {
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index ee0f50e3495b..b49f15008cff 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -24,6 +24,7 @@
 #include <tvm/runtime/data_type.h>
 #include <tvm/runtime/registry.h>
 
+#include "../../3rdparty/compiler-rt/builtin_fp16.h"
 #include "../cblas/gemm_common.h"
 #include "cublas_utils.h"
 
@@ -133,6 +134,120 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
 int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 
 #if CUDART_VERSION >= 10010
+
+void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias,
+                  const DLTensor* C, bool transa, bool transb, cublasLtEpilogue_t epilogue) {
+  ICHECK(TypeEqual(A->dtype, B->dtype));
+  // Reversed strides indicate an in-place transpose operation.
+  transa = IsInPlaceTransposed(A) ? !transa : transa;
+  transb = IsInPlaceTransposed(B) ? !transb : transb;
+
+  auto compute_type = CUBLAS_COMPUTE_32F;
+  auto scale_type = CUDA_R_32F;
+  cudaDataType_t ab_type = CUDA_R_32F;
+  cudaDataType_t c_type = CUDA_R_32F;
+  float one_fp32 = 1.0;
+  float zero_fp32 = 0.0;
+  auto one_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(1.0);
+  auto zero_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(0.0);
+  void* alpha = &one_fp32;
+  void* beta = &zero_fp32;
+
+  if (A->dtype.bits == 16 && A->dtype.code == kDLFloat) {
+    ab_type = CUDA_R_16F;
+  }
+
+  if (C->dtype.bits == 16 && C->dtype.code == kDLFloat) {
+    c_type = CUDA_R_16F;
+    compute_type = CUBLAS_COMPUTE_16F;
+    scale_type = CUDA_R_16F;
+    alpha = &one_fp16;
+    beta = &zero_fp16;
+  }
+
+  cublasLtMatmulDesc_t op_desc;
+  cublasOperation_t op_transa = CUBLASBooleanToTranspose(transa);
+  cublasOperation_t op_transb = CUBLASBooleanToTranspose(transb);
+
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
+  // TVM tensors are row major while cuBLASLt assumes column major, so compute
+  // (B^T A^T)^T = AB by swapping the roles of A and B: TRANSA, which describes
+  // the first cuBLASLt operand, receives op_transb and vice versa.
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSA,
+                                                    &op_transb, sizeof(op_transb)));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSB,
+                                                    &op_transa, sizeof(op_transa)));
+
+  if (bias != nullptr) {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+                                                      &bias->data, sizeof(float*)));
+  }
+
+  if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
+                                                      &epilogue, sizeof(epilogue)));
+  }
+
+  int batch_offset_A = A->ndim - 2;
+  int batch_offset_B = B->ndim - 2;
+
+  int M = ColumnCount(B, transb, batch_offset_B);
+  int N = RowCount(A, transa, batch_offset_A);
+  int K = ColumnCount(A, transa, batch_offset_A);
+
+  int lda = transb ? K : M;
+  int ldb = transa ? N : K;
+  int ldc = M;
+
+  cublasLtMatrixLayout_t A_desc, B_desc, C_desc;
+  CHECK_CUBLAS_ERROR(
+      cublasLtMatrixLayoutCreate(&A_desc, ab_type, !transb ? M : K, !transb ? K : M, lda));
+  CHECK_CUBLAS_ERROR(
+      cublasLtMatrixLayoutCreate(&B_desc, ab_type, !transa ? K : N, !transa ? N : K, ldb));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&C_desc, c_type, M, N, ldc));
+
+  if (A->ndim != 2 || B->ndim != 2) {
+    auto get_batch_count = [](int64_t* shape, int batch_offset) {
+      int64_t count = 1;
+      for (int i = 0; i < batch_offset; ++i) {
+        count *= shape[i];
+      }
+      return count;
+    };
+    auto set_batch = [](cublasLtMatrixLayout_t mat_desc, int batch_count, int64_t batch_stride) {
+      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+          mat_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
+      CHECK_CUBLAS_ERROR(
+          cublasLtMatrixLayoutSetAttribute(mat_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+                                           &batch_stride, sizeof(batch_stride)));
+    };
+
+    int batch_count_A = get_batch_count(A->shape, batch_offset_A);
+    int batch_count_B = get_batch_count(B->shape, batch_offset_B);
+    int batch_count_C = get_batch_count(C->shape, C->ndim - 2);
+    int64_t batch_stride_A = M * K;
+    int64_t batch_stride_B = K * N;
+    int64_t batch_stride_C = M * N;
+
+    // cuBLASLt does not seem to support batched GEMM where only one of the two
+    // operands is batched (i.e. has batch_stride 0).
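+    // The cuBLAS pattern check (_check_matmul in
+    // python/tvm/relax/backend/contrib/cublas.py) rejects such workloads
+    // before they reach this point, so equal batch counts can be asserted.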
+    ICHECK_EQ(batch_count_A, batch_count_B);
+
+    set_batch(A_desc, batch_count_A, batch_stride_A);
+    set_batch(B_desc, batch_count_B, batch_stride_B);
+    set_batch(C_desc, batch_count_C, batch_stride_C);
+  }
+
+  auto A_data = static_cast<char*>(A->data) + A->byte_offset;
+  auto B_data = static_cast<char*>(B->data) + B->byte_offset;
+  auto C_data = static_cast<char*>(C->data) + C->byte_offset;
+
+  // Pass B as the first operand and A as the second (see the transpose note
+  // above), so that the column-major result matches the row-major layout of C.
+  CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, A_data, B_desc, beta,
+                                    C_data, C_desc, C_data, C_desc, nullptr, nullptr, 0, nullptr));
+
+  cublasLtMatmulDescDestroy(op_desc);
+  cublasLtMatrixLayoutDestroy(A_desc);
+  cublasLtMatrixLayoutDestroy(B_desc);
+  cublasLtMatrixLayoutDestroy(C_desc);
+}
+
 inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) {
   DLTensor* A = args[0];
   DLTensor* B = args[1];
@@ -172,7 +287,6 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) {
   auto B_data = reinterpret_cast<void*>(static_cast<char*>(B->data) + B->byte_offset);
   auto C_data = reinterpret_cast<void*>(static_cast<char*>(C->data) + C->byte_offset);
 
-  cublasOperation_t opTranspose = CUBLAS_OP_T;
   cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
   cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C;
   cublasLtMatmulDesc_t operationDesc = nullptr;
@@ -181,8 +295,6 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) {
 #else
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I));
 #endif
-  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
-                                                    &opTranspose, sizeof(opTranspose)));
   cublasOperation_t opTransA = CUBLASBooleanToTranspose(transa);
   cublasOperation_t opTransB = CUBLASBooleanToTranspose(transb);
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc
new file mode 100644
index 000000000000..8afccb27302d
--- /dev/null
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/cublas/cublas_json_runtime.cc
+ * \brief A simple JSON runtime for CUBLAS.
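+ *
+ * Composite op names produced by the cuBLAS pattern table are mapped onto
+ * cuBLASLt epilogues by substring matching: "bias" selects
+ * CUBLASLT_EPILOGUE_BIAS, while "relu" / "gelu" select the fused
+ * bias-plus-activation epilogues.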
+ */
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstddef>
+#include <regex>
+#include <string>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+#include "cublas_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime;
+using namespace tvm::runtime::json;
+
+class CublasJSONRuntime : public JSONRuntimeBase {
+ public:
+  CublasJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
+                    const Array<String> const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+  void Init(const Array<NDArray>& consts) override {}
+
+  void Run() override {
+    // TODO(masahi): Reuse the same handle across different subgraphs
+    cublasLtHandle_t handle;
+    cublasLtCreate(&handle);
+
+    for (size_t i = 0; i < nodes_.size(); ++i) {
+      const auto& node = nodes_[i];
+      if (node.GetOpType() == "kernel") {
+        auto op_name = node.GetOpName();
+        uint32_t output_eid = EntryID(outputs_[0]);
+        auto out_ptr = data_entry_[output_eid];
+        bool transa = false;
+        bool transb = false;
+        cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
+
+        if (op_name.find("transposed") != std::string::npos) {
+          transb = true;
+        }
+
+        if (op_name.find("relu") != std::string::npos) {
+          epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
+        } else if (op_name.find("gelu") != std::string::npos) {
+          epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
+        } else if (op_name.find("bias") != std::string::npos) {
+          epilogue = CUBLASLT_EPILOGUE_BIAS;
+        }
+
+        auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) {
+          const DLTensor* bias = nullptr;
+          if (has_bias) {
+            bias = GetInput(node, 2);
+          }
+          return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias);
+        };
+
+        auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);
+
+        tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, transa, transb,
+                                   epilogue);
+      }
+    }
+    cublasLtDestroy(handle);
+  }
+
+ private:
+  const DLTensor* GetInput(const JSONGraphNode& node, const int idx) {
+    ICHECK_LT(idx, node.GetInputs().size());
+    auto eid = EntryID(node.GetInputs()[idx]);
+    ICHECK(eid < data_entry_.size());
+    return data_entry_[eid];
+  }
+};
+
+runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json,
+                                        const Array<String>& const_names) {
+  auto n = make_object<CublasJSONRuntime>(symbol_name, graph_json, const_names);
+  return runtime::Module(n);
+}
+
+TVM_REGISTER_GLOBAL("runtime.CublasJSONRuntimeCreate").set_body_typed(CublasJSONRuntimeCreate);
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cublas_json")
+    .set_body_typed(JSONRuntimeBase::LoadFromBinary<CublasJSONRuntime>);
+
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index 62863b8f7bc8..ac03b1210366 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -104,6 +104,12 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
   }
   LOG(FATAL) << "Unsupported cuda type";
 }
+
+/*!
+ * \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt.
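+ *
+ * A, B and C follow the TVM row-major convention; internally the operands are
+ * swapped so that the column-major cuBLASLt kernel produces the row-major
+ * result. Tensors with more than two dimensions are executed as strided
+ * batched GEMM, with all leading axes folded into the batch count.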
+ */
+void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias,
+                  const DLTensor* C, bool transa, bool transb,
+                  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT);
+
 }  // namespace contrib
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
new file mode 100644
index 000000000000..023054256efd
--- /dev/null
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+import tvm.topi.testing
+from tvm import relax
+from tvm.relax.backend.contrib.cublas import partition_for_cublas
+from tvm.relax.testing import get_relax_matmul_module
+from tvm.script import relax as R
+
+
+@pytest.fixture(autouse=True)
+def reset_seed():
+    np.random.seed(0)
+
+
+has_cublas = tvm.get_global_func("relax.ext.cublas", True)
+
+cublas_enabled = pytest.mark.skipif(
+    not has_cublas,
+    reason="CUBLAS not enabled.",
+)
+
+pytestmark = [cublas_enabled]
+
+
+def build_and_run(mod, inputs_np, target, legalize=False):
+    if legalize:
+        mod = relax.transform.LegalizeOps()(mod)
+
+    dev = tvm.device(target, 0)
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, dev)
+    f = vm["main"]
+    inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
+    return f(*inputs).numpy()
+
+
+def get_result_with_relax_cublas_offload(mod, *args):
+    mod = partition_for_cublas(mod)
+    mod = relax.transform.RunCodegen()(mod)
+
+    return build_and_run(mod, args, "cuda")
+
+
+def _to_concrete_shape(symbolic_shape, var_table):
+    result = []
+    for dim in symbolic_shape:
+        if not isinstance(dim, tvm.tir.expr.Var):
+            result.append(dim)
+            continue
+
+        if dim not in var_table:
+            var_table[dim] = np.random.randint(10, 50)
+        result.append(var_table[dim])
+
+    return tuple(result)
+
+
+_vars = {
+    "a": tvm.tir.expr.Var("a", "int64"),
+    "b": tvm.tir.expr.Var("b", "int64"),
+}
+
+
+_epilogue_table = {
+    "none": (False, None),
+    "bias": (True, None),
+    "relu": (True, R.nn.relu),
+    "gelu": (True, R.nn.gelu),
+}
+
+
+@pytest.mark.parametrize(
+    "x_shape, y_shape, transpose_y, epilogue",
+    [
+        # Regular
+        ((8, 8), (8, 8), False, "none"),
+        ((_vars["a"], 6), (6, 16), False, "bias"),
+        # Transposed
+        ((4, 16), (16, 128), True, "relu"),
+        ((35, 8), (8, 8), True, "gelu"),
+        # 3D x 3D
+        ((6, 32, 8), (6, 8, 10), False, "bias"),
+        ((6, 32, 8), (6, 8, 10), True, "none"),
+        ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"),
+        # ND x ND
+        ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        "float16",
+        "float32",
+    ],
+)
+def test_matmul_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    epilogue,
+    dtype,
+):
+    with_bias, activation = _epilogue_table[epilogue]
+    var_table = {}
+    concrete_x_shape = _to_concrete_shape(x_shape, var_table)
+    concrete_y_shape = _to_concrete_shape(y_shape, var_table)
+    x = np.random.randn(*concrete_x_shape).astype(dtype)
+    y = np.random.randn(*concrete_y_shape).astype(dtype)
+
+    if transpose_y:
+        y = np.swapaxes(y, -2, -1)
+        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
+
+    if with_bias:
+        bias = np.random.randn(concrete_y_shape[-1]).astype(dtype)
+        args = (x, y, bias)
+    else:
+        bias = None
+        args = (x, y)
+
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        dtype,
+        with_bias=with_bias,
+        transposed_y=transpose_y,
+        activation=activation,
+    )
+
+    out = get_result_with_relax_cublas_offload(mod, *args)
+    ref = build_and_run(mod, args, "llvm", legalize=True)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index c8ca44311de5..b9ba4f4dc9af 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -23,8 +23,8 @@
 from tvm import relax
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
 from tvm.contrib.pickle_memoize import memoize
-from tvm.relax.backend import get_patterns_with_prefix
 from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
+from tvm.relax.testing import get_relax_matmul_module
 from tvm.script import relax as R
 from tvm.script.ir_builder import IRBuilder
 from tvm.script.ir_builder import relax as relax_builder
@@ -96,9 +96,6 @@ def build_and_run(mod, inputs_np, target, legalize=False):
 
 
 def get_result_with_relax_cutlass_offload(mod, *args, assert_all_bindings_fused=True):
-    patterns = [(entry.name, entry.pattern) for entry in get_patterns_with_prefix("cutlass")]
-    assert len(patterns) != 0, "Cannot find cutlass patterns"
-
     mod = partition_for_cutlass(mod)
 
     if assert_all_bindings_fused:
@@ -168,50 +165,6 @@ def get_relax_conv2d_module(
     return tvm.IRModule({"main": func})
 
 
-def get_relax_matmul_module(
-    x_shape,
-    y_shape,
-    dtype,
-    transposed_y=False,
-    with_bias=False,
-    activation=None,
-    residual_bin_op=None,
-    residual_activation=None,
-):
-    if transposed_y:
-        n = y_shape[-2]
-    else:
-        n = y_shape[-1]
-
-    with IRBuilder() as builder:
-        with relax_builder.function():
-            R.func_name("main")
-            x = R.arg("x", R.Tensor(x_shape, dtype))
-            y = R.arg("y", R.Tensor(y_shape, dtype))
-            if with_bias:
-                bias = R.arg("bias", R.Tensor((n,), dtype))
-
-            with R.dataflow() as frame:
-                if transposed_y:
-                    axes = list(range(len(y_shape) - 2)) + [-1, -2]
-                    y = R.emit(R.permute_dims(y, axes=axes))
-                result = R.emit(R.matmul(x, y, out_dtype=dtype))
-                if with_bias:
-                    result = R.emit(result + bias)
-                if activation is not None:
-                    result = R.emit(activation(result))
-                if residual_bin_op is not None:
-                    result = R.emit(residual_bin_op(result, x))
-                if residual_activation is not None:
-                    result = R.emit(residual_activation(result))
-                R.output(result)
-
-            R.func_ret_value(frame.output_vars[0])
-
-    func = builder.get()
-    return tvm.IRModule({"main": func})
-
-
 def _to_concrete_shape(symbolic_shape, var_table):
     result = []
     for dim in symbolic_shape: