From 0683120a091b0d1a73555cdb0c27319471d58d30 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 4 Mar 2023 21:50:03 +0900 Subject: [PATCH 01/15] stub --- src/relax/backend/contrib/cublas/codegen.cc | 105 ++++++++++++++++++ .../contrib/cublas/cublas_json_runtime.cc | 97 ++++++++++++++++ 2 files changed, 202 insertions(+) create mode 100644 src/relax/backend/contrib/cublas/codegen.cc create mode 100644 src/runtime/contrib/cublas/cublas_json_runtime.cc diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc new file mode 100644 index 000000000000..8df55191153c --- /dev/null +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/cublas/codegen.cc + * \brief Implementation of the CUBLAS JSON serializer. + */ +#include + +#include + +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONSerializer = backend::contrib::JSONSerializer; +using backend::contrib::NodeEntries; + +class CUBLASJSONSerializer : public JSONSerializer { + public: + CUBLASJSONSerializer(Map constant_names, Map bindings) + : JSONSerializer(constant_names), bindings_(bindings) {} + + using JSONSerializer::VisitExpr_; + + NodeEntries VisitExpr_(const CallNode* call_node) final { + const auto* fn_var = call_node->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + ICHECK(fn.defined()) << "Expects the callee to be a function."; + + auto composite_opt = fn->GetAttr(attr::kComposite); + ICHECK(composite_opt.defined()) << "Only composite functions are supported."; + + std::string composite_name = composite_opt.value(); + + NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(composite_name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + + const CallNode* root_call = nullptr; + if (composite_name.find("conv2d") != std::string::npos) { + root_call = backend::GetOpInFunction(fn, "relax.nn.conv2d"); + } else { + LOG(FATAL) << "Unimplemented pattern: " << composite_name; + } + + SetCallNodeAttribute(node, root_call); + return AddNode(node, GetRef(call_node)); + } + + private: + /*! \brief The bindings to look up composite functions. 
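+   *
+   *  A partitioned call site refers to its composite function through a local
+   *  Var; VisitExpr_ uses this map to resolve that Var and read the function's
+   *  kComposite attribute.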
*/ + Map bindings_; +}; + +Array CublasCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; + + for (const auto& func : functions) { + CUBLASJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + serializer.serialize(func); + auto graph_json = serializer.GetJSON(); + auto constant_names = serializer.GetConstantNames(); + const auto* pf = runtime::Registry::Get("runtime.CUBLASJSONRuntimeCreate"); + ICHECK(pf != nullptr) << "Cannot find CUBLAS runtime module create function."; + auto func_name = GetExtSymbol(func); + compiled_functions.push_back((*pf)(func_name, graph_json, constant_names)); + } + + return compiled_functions; +} + +TVM_REGISTER_GLOBAL("relax.ext.cublas").set_body_typed(CublasCompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc new file mode 100644 index 000000000000..9866225db8df --- /dev/null +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/cublas/cublas_json_runtime.cc + * \brief A simple JSON runtime for CUBLAS. + */ + +#include +#include + +#include +#include +#include +#include + +#include "../json/json_node.h" +#include "../json/json_runtime.h" + +// TODO(@apeskov): Have to mute warning from cublas headers. +// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command + +#include "cublas_tensor_requisite.h" +#include "cublas_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime; +using namespace tvm::runtime::json; + +class CUBLASJSONRuntime : public JSONRuntimeBase { + public: + CUBLASJSONRuntime(const std::string& symbol_name, const std::string& graph_json, + const Array const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names), + next_unique_eid_offset_(data_entry_.size()), + run_arg_eid_(input_var_eid_) { + } + + /* Unused stub implementation */ + void Run() override { LOG(FATAL) << "Unreachable code"; } + + /* Thread safe implementation of Run. 
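+     Inputs and outputs arrive through the packed args (GetFunction below
+     checks args.size() against input_var_eid_ and outputs_), so no per-call
+     state is written to the module.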
Keep runtime instance immutable */ + void Run(const TVMArgs& args) const { + } + + /* Override GetFunction to reimplement Run method */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { + if (this->symbol_name_ == name) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK(this->initialized_) << "The module has not been initialized"; + + ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) + << "Found mismatch in the number of provided data entries and required."; + + Run(args); + }); + } else { + return JSONRuntimeBase::GetFunction(name, sptr_to_self); + } + } + + private: +}; + +runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { + auto n = make_object(symbol_name, graph_json, const_names); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.CublasJSONRuntimeCreate").set_body_typed(CubblasJSONRuntimeCreate); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cublas_json") + .set_body_typed(JSONRuntimeBase::LoadFromBinary); + +} // namespace contrib +} // namespace runtime +} // namespace tvm From 91b8c0cbce96e8117fc7379dceda1bbfd398757b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Mar 2023 10:24:04 +0900 Subject: [PATCH 02/15] fixed build --- cmake/modules/CUDA.cmake | 4 ++-- .../contrib/cublas/cublas_json_runtime.cc | 20 +++++++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index 96d5922e84d9..1502c4f8bc80 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -50,8 +50,8 @@ if(USE_CUDA) if(USE_CUBLAS) message(STATUS "Build with cuBLAS support") - tvm_file_glob(GLOB CUBLAS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cublas/*.cc) - list(APPEND COMPILER_SRCS ${CUBLAS_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB CUBLAS_CONTRIB_SRC src/relay/backend/contrib/cublas/*.cc src/relax/backend/contrib/cublas/*.cc) + list(APPEND COMPILER_SRCS ${CUBLAS_CONTRIB_SRC}) tvm_file_glob(GLOB CONTRIB_CUBLAS_SRCS src/runtime/contrib/cublas/*.cc) list(APPEND RUNTIME_SRCS ${CONTRIB_CUBLAS_SRCS}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUBLAS_LIBRARY}) diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 9866225db8df..8a4bdfbf15b7 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -36,7 +36,6 @@ // TODO(@apeskov): Have to mute warning from cublas headers. // -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command -#include "cublas_tensor_requisite.h" #include "cublas_utils.h" namespace tvm { @@ -46,21 +45,20 @@ namespace contrib { using namespace tvm::runtime; using namespace tvm::runtime::json; -class CUBLASJSONRuntime : public JSONRuntimeBase { +class CublasJSONRuntime : public JSONRuntimeBase { public: - CUBLASJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) - : JSONRuntimeBase(symbol_name, graph_json, const_names), - next_unique_eid_offset_(data_entry_.size()), - run_arg_eid_(input_var_eid_) { + CublasJSONRuntime(const std::string& symbol_name, const std::string& graph_json, + const Array const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + + void Init(const Array& consts) override { } /* Unused stub implementation */ void Run() override { LOG(FATAL) << "Unreachable code"; } /* Thread safe implementation of Run. 
Keep runtime instance immutable */ - void Run(const TVMArgs& args) const { - } + void Run(const TVMArgs& args) const {} /* Override GetFunction to reimplement Run method */ PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { @@ -82,12 +80,12 @@ class CUBLASJSONRuntime : public JSONRuntimeBase { }; runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.CublasJSONRuntimeCreate").set_body_typed(CubblasJSONRuntimeCreate); +TVM_REGISTER_GLOBAL("runtime.CublasJSONRuntimeCreate").set_body_typed(CublasJSONRuntimeCreate); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cublas_json") .set_body_typed(JSONRuntimeBase::LoadFromBinary); From c4939622b0e9518b7d32cf8fd0a8d5016bb2faca Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Mar 2023 10:36:30 +0900 Subject: [PATCH 03/15] test stub --- src/relax/backend/contrib/cublas/codegen.cc | 12 +- tests/python/relax/test_codegen_cublas.py | 208 ++++++++++++++++++++ 2 files changed, 214 insertions(+), 6 deletions(-) create mode 100644 tests/python/relax/test_codegen_cublas.py diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 8df55191153c..8fb23f29c0ff 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -37,9 +37,9 @@ using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; using JSONSerializer = backend::contrib::JSONSerializer; using backend::contrib::NodeEntries; -class CUBLASJSONSerializer : public JSONSerializer { +class CublasJSONSerializer : public JSONSerializer { public: - CUBLASJSONSerializer(Map constant_names, Map bindings) + CublasJSONSerializer(Map constant_names, Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; @@ -65,8 +65,8 @@ class CUBLASJSONSerializer : public JSONSerializer { inputs, 1 /* num_outputs_ */); const CallNode* root_call = nullptr; - if (composite_name.find("conv2d") != std::string::npos) { - root_call = backend::GetOpInFunction(fn, "relax.nn.conv2d"); + if (composite_name.find("matmul") != std::string::npos) { + root_call = backend::GetOpInFunction(fn, "relax.matmul"); } else { LOG(FATAL) << "Unimplemented pattern: " << composite_name; } @@ -85,11 +85,11 @@ Array CublasCompiler(Array functions, Map compiled_functions; for (const auto& func : functions) { - CUBLASJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + CublasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); serializer.serialize(func); auto graph_json = serializer.GetJSON(); auto constant_names = serializer.GetConstantNames(); - const auto* pf = runtime::Registry::Get("runtime.CUBLASJSONRuntimeCreate"); + const auto* pf = runtime::Registry::Get("runtime.CublasJSONRuntimeCreate"); ICHECK(pf != nullptr) << "Cannot find CUBLAS runtime module create function."; auto func_name = GetExtSymbol(func); compiled_functions.push_back((*pf)(func_name, graph_json, constant_names)); diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py new file mode 100644 index 000000000000..0986d3caf29f --- /dev/null +++ b/tests/python/relax/test_codegen_cublas.py @@ -0,0 +1,208 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest + +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relax +from tvm.contrib.pickle_memoize import memoize +from tvm.relax.backend import get_patterns_with_prefix +from tvm.script import relax as R + + +@pytest.fixture(autouse=True) +def reset_seed(): + np.random.seed(0) + + +has_cublas = tvm.get_global_func("relax.ext.cublas", True) + +cublas_enabled = pytest.mark.skipif( + not has_cublas, + reason="CUBLAS not enabled.", +) + +pytestmark = [cublas_enabled] + + +def build_and_run(mod, inputs_np, target, legalize=False): + if legalize: + mod = relax.transform.LegalizeOps()(mod) + + dev = tvm.device(target, 0) + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, dev) + f = vm["main"] + inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + return f(*inputs).numpy() + + +def get_result_with_relax_cublas_offload(mod, *args): + patterns = [(entry.name, entry.pattern) for entry in get_patterns_with_prefix("cublas")] + assert len(patterns) != 0, "Cannot find cublas patterns" + + # TODO: partition + codegen_pass = relax.transform.RunCodegen() + mod = codegen_pass(mod) + + return build_and_run(mod, args, "cuda") + + +def get_relax_matmul_module( + x_shape, y_shape, dtype, transposed_y=False, with_bias=False, activation=None +): + if transposed_y: + n = y_shape[-2] + else: + n = y_shape[-1] + + from tvm.script.ir_builder import IRBuilder + from tvm.script.ir_builder import relax as relax_builder + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + x = R.arg("x", R.Tensor(x_shape, dtype)) + y = R.arg("y", R.Tensor(y_shape, dtype)) + if with_bias: + bias = R.arg("bias", R.Tensor((n,), dtype)) + + with R.dataflow() as frame: + if transposed_y: + axes = list(range(len(y_shape) - 2)) + [-1, -2] + y = R.emit(R.permute_dims(y, axes=axes)) + result = R.emit(R.matmul(x, y, out_dtype=dtype)) + if with_bias: + result = R.emit(result + bias) + if activation is not None: + result = R.emit(activation(result)) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def _to_concrete_shape(symbolic_shape, var_table): + result = [] + for dim in symbolic_shape: + if not isinstance(dim, tvm.tir.expr.Var): + result.append(dim) + continue + + if dim not in var_table: + var_table[dim] = np.random.randint(10, 50) + result.append(var_table[dim]) + + return tuple(result) + + +_vars = { + "a": tvm.tir.expr.Var("a", "int64"), + "b": tvm.tir.expr.Var("b", "int64"), +} + + +_epilogue_table = { + "none": (False, None), + "bias": (True, None), + "relu": (True, R.nn.relu), + "gelu": (True, R.nn.gelu), +} + + +@pytest.mark.parametrize( + "x_shape, y_shape, transpose_y, epilogue", + [ + # Regular + ((32, 6), (6, 16), False, "none"), + ((_vars["a"], 6), (6, 16), False, "bias"), + 
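+        # Each case is (x_shape, y_shape, transpose_y, epilogue); tir.Var dims stay symbolic.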
# Transposed + ((4, 16), (16, 128), True, "relu"), + ((35, 8), (8, 8), True, "gelu"), + # 3D x 3D + ((6, 32, 8), (6, 8, 10), False, "bias"), + ((6, 32, 8), (6, 8, 10), True, "none"), + ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"), + # 3D x 2D + ((6, 32, 8), (8, 10), False, "none"), + ((_vars["a"], 32, 8), (8, 10), False, "bias"), + ((10, 16, 8), (8, 10), True, "relu"), + # 2D x 3D + ((32, 8), (10, 8, 10), False, "relu"), + ((32, 8), (_vars["a"], 8, 10), True, "gelu"), + # ND x 2D + ((3, 6, 32, 8), (8, 10), False, "bias"), + ((_vars["a"], _vars["b"], 6, 32, 8), (8, 10), False, "none"), + # 2D x ND + ((32, 8), (5, 3, 8, 10), False, "gelu"), + # ND x ND + ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"), + ((3, 2, 4, 16, 15), (1, 1, 15, 2), True, "gelu"), + ((1, 1, 16, 15), (3, 2, _vars["a"], 15, 2), False, "none"), + ], +) +@pytest.mark.parametrize( + "dtype", + [ + "float16", + ], +) +def test_matmul_offload( + x_shape, + y_shape, + transpose_y, + epilogue, + dtype, +): + with_bias, activation = _epilogue_table[epilogue] + var_table = {} + concrete_x_shape = _to_concrete_shape(x_shape, var_table) + concrete_y_shape = _to_concrete_shape(y_shape, var_table) + x = np.random.randn(*concrete_x_shape).astype(dtype) + y = np.random.randn(*concrete_y_shape).astype(dtype) + + if transpose_y: + y = np.swapaxes(y, -2, -1) + y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2]) + + if with_bias: + bias = np.random.randn(concrete_y_shape[-1]).astype(dtype) + args = (x, y, bias) + else: + bias = None + args = (x, y) + + mod = get_relax_matmul_module( + x_shape, + y_shape, + dtype, + with_bias=with_bias, + transposed_y=transpose_y, + activation=activation, + ) + out = get_result_with_relax_cublas_offload(mod, *args) + ref = build_and_run(mod, args, "llvm", legalize=True) + + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + tvm.testing.main() From 8e8fa355fe44b80df5b4120cd7a958c5307562e4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Mar 2023 20:11:35 +0900 Subject: [PATCH 04/15] basic gemm working --- python/tvm/relax/backend/contrib/cublas.py | 134 ++++++++++++++++++ python/tvm/relax/backend/contrib/cutlass.py | 10 -- src/runtime/contrib/cblas/gemm_common.h | 10 +- src/runtime/contrib/cublas/cublas.cc | 56 +++++++- .../contrib/cublas/cublas_json_runtime.cc | 47 +++--- src/runtime/contrib/cublas/cublas_utils.h | 4 + tests/python/relax/test_codegen_cublas.py | 73 +++++----- tests/python/relax/test_codegen_cutlass.py | 4 - 8 files changed, 256 insertions(+), 82 deletions(-) create mode 100644 python/tvm/relax/backend/contrib/cublas.py diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py new file mode 100644 index 000000000000..96c7b48b9283 --- /dev/null +++ b/python/tvm/relax/backend/contrib/cublas.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern table for CUTLASS backend""" + +from typing import Mapping, Optional, Tuple + +import tvm +from tvm.relax import Call, Expr, ShapeExpr, transform +from tvm.relax.dpl import CallPattern, DFPattern + +from ..pattern_registry import get_patterns_with_prefix, register_patterns +from ..patterns import make_matmul_pattern + + +def _is_supported_dtype(lhs_dtype, rhs_dtype): + """Check if dtypes in the given workload are supported by CUTLASS.""" + return ( + (lhs_dtype == "float16" and rhs_dtype == "float16") + or (lhs_dtype == "float32" and rhs_dtype == "float32") + or (lhs_dtype in ("int8", "uint8") and rhs_dtype in ("int8", "uint8")) + ) + + +def _check_matmul( + match_result: Mapping[DFPattern, Expr], + _: Expr, +) -> bool: + return True + + +register_patterns( + [ + ( + "cublas.matmul", + *make_matmul_pattern( + with_bias=False, + ), + _check_matmul, + ), + ( + "cublas.matmul_bias", + *make_matmul_pattern( + with_bias=True, + ), + _check_matmul, + ), + ( + "cublas.matmul_bias_relu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.relu", + ), + _check_matmul, + ), + ( + "cublas.matmul_bias_gelu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.gelu", + ), + _check_matmul, + ), + ( + "cublas.matmul_transposed", + *make_matmul_pattern( + with_bias=False, + transposed_rhs=True, + ), + _check_matmul, + ), + ( + "cublas.matmul_transposed_bias", + *make_matmul_pattern( + with_bias=True, + transposed_rhs=True, + ), + _check_matmul, + ), + ( + "cublas.matmul_transposed_bias_relu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.relu", + transposed_rhs=True, + ), + _check_matmul, + ), + ( + "cublas.matmul_transposed_bias_gelu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.gelu", + transposed_rhs=True, + ), + _check_matmul, + ), + ] +) + + +def partition_for_cublas(mod): + """ + Partition the input module into CUTLASS-supported subgraphs. + + Parameters + ---------- + mod: tvm.IRModule + The IRModule to be partitioned. + + Returns + ------- + mod: tvm.IRModule + The resulting IRModule, containing partitioned subgraphs to be + compiled by the CUTLASS backend. 
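+
+    Examples
+    --------
+    A minimal flow, mirroring the unit tests (assumes a CUDA build with
+    cuBLAS enabled):
+
+    .. code-block:: python
+
+        mod = partition_for_cublas(mod)
+        mod = relax.transform.RunCodegen()(mod)
+        ex = relax.build(mod, "cuda")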
+ """ + + cublas_pattern_entries = get_patterns_with_prefix("cublas") + patterns = [(e.name, e.pattern, e.check) for e in cublas_pattern_entries] + return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod) diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index c03c913d63cd..8b243e52eae2 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -33,16 +33,6 @@ ) -def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]: - result = [] - for dim in shape.values: - if isinstance(dim, tvm.tir.expr.IntImm): - result.append(int(dim)) - else: - return None - return result - - def _is_supported_dtype(lhs_dtype, rhs_dtype): """Check if dtypes in the given workload are supported by CUTLASS.""" return ( diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index 4724b14bffa1..9946484ae297 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -46,7 +46,7 @@ inline int ColumnStride(DLTensor* tensor) { } } -inline int ElementStride(DLTensor* tensor) { +inline int ElementStride(const DLTensor* tensor) { if (tensor->strides) { return std::min(tensor->strides[0], tensor->strides[1]); } else { @@ -55,13 +55,13 @@ inline int ElementStride(DLTensor* tensor) { } // Reversed strides indicates an in-place transpose operation. -inline bool IsInPlaceTransposed(DLTensor* tensor) { +inline bool IsInPlaceTransposed(const DLTensor* tensor) { return tensor->strides && (tensor->strides[1] > tensor->strides[0]); } -inline int RowCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 1 : 0]; } +inline int RowCount(const DLTensor* tensor, bool trans) { return tensor->shape[trans ? 1 : 0]; } -inline int ColumnCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 0 : 1]; } +inline int ColumnCount(const DLTensor* tensor, bool trans) { return tensor->shape[trans ? 0 : 1]; } // Call a column major blas. Note that data is stored in tvm as row // major, so this we switch the arguments. @@ -159,7 +159,7 @@ inline int ColumnStride3D(DLTensor* tensor) { return tensor->shape[2]; } } -inline int ElementStride3D(DLTensor* tensor) { +inline int ElementStride3D(const DLTensor* tensor) { if (tensor->strides) { return std::min(tensor->strides[1], tensor->strides[2]); } else { diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index ee0f50e3495b..aff2cfd135f6 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -133,6 +133,59 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s int roundoff(int v, int d) { return (v + d - 1) / d * d; } #if CUDART_VERSION >= 10010 + +void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* C, + bool transa, bool transb, float alpha, float beta) { + ICHECK_EQ(A->ndim, 2); + ICHECK_EQ(B->ndim, 2); + ICHECK_EQ(C->ndim, 2); + ICHECK_EQ(ElementStride(A), 1); + ICHECK_EQ(ElementStride(B), 1); + ICHECK_EQ(ElementStride(C), 1); + ICHECK(TypeEqual(A->dtype, B->dtype)); + + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed(A) ? !transa : transa; + transb = IsInPlaceTransposed(B) ? 
!transb : transb; + + int M = ColumnCount(B, transb); + int N = RowCount(A, transa); + int K = ColumnCount(A, transa); + + cublasLtMatmulDesc_t operationDesc = nullptr; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + + cublasOperation_t opTransA = CUBLASBooleanToTranspose(transa); + cublasOperation_t opTransB = CUBLASBooleanToTranspose(transb); + + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTransA, sizeof(opTransA))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTransB, sizeof(opTransB))); + + cublasLtMatrixLayout_t Adesc, Bdesc, Cdesc; + cudaDataType_t ab_type = CUDA_R_32F; + cudaDataType_t c_type = CUDA_R_32F; + + // TODO + int lda = M; + int ldb = K; + int ldc = M; + + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, ab_type, opTransA == CUBLAS_OP_N ? M : K, + opTransA == CUBLAS_OP_N ? K : M, lda)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, ab_type, opTransB == CUBLAS_OP_N ? K : N, + opTransB == CUBLAS_OP_N ? N : K, ldb)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, c_type, M, N, ldc)); + + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, operationDesc, &alpha, B_data, Adesc, A_data, Bdesc, &beta, + C_data, Cdesc, C_data, Cdesc, nullptr, nullptr, 0, nullptr)); +} + inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) { DLTensor* A = args[0]; DLTensor* B = args[1]; @@ -172,7 +225,6 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) { auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); - cublasOperation_t opTranspose = CUBLAS_OP_T; cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C; cublasLtMatmulDesc_t operationDesc = nullptr; @@ -181,8 +233,6 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) { #else CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I)); #endif - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, - &opTranspose, sizeof(opTranspose))); cublasOperation_t opTransA = CUBLASBooleanToTranspose(transa); cublasOperation_t opTransB = CUBLASBooleanToTranspose(transb); CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 8a4bdfbf15b7..98087022fe15 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -33,9 +33,6 @@ #include "../json/json_node.h" #include "../json/json_runtime.h" -// TODO(@apeskov): Have to mute warning from cublas headers. 
-// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command - #include "cublas_utils.h" namespace tvm { @@ -49,34 +46,38 @@ class CublasJSONRuntime : public JSONRuntimeBase { public: CublasJSONRuntime(const std::string& symbol_name, const std::string& graph_json, const Array const_names) - : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + : JSONRuntimeBase(symbol_name, graph_json, const_names) { + } void Init(const Array& consts) override { } - /* Unused stub implementation */ - void Run() override { LOG(FATAL) << "Unreachable code"; } - - /* Thread safe implementation of Run. Keep runtime instance immutable */ - void Run(const TVMArgs& args) const {} - - /* Override GetFunction to reimplement Run method */ - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { - if (this->symbol_name_ == name) { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK(this->initialized_) << "The module has not been initialized"; - - ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) - << "Found mismatch in the number of provided data entries and required."; - - Run(args); - }); - } else { - return JSONRuntimeBase::GetFunction(name, sptr_to_self); + void Run() override{ + cublasLtHandle_t handle; + cublasLtCreate(&handle); + + for (size_t i = 0; i < nodes_.size(); ++i) { + const auto& node = nodes_[i]; + if (node.GetOpType() == "kernel") { + auto op_name = node.GetOpName(); + if (op_name == "cublas.matmul") { + auto a_ptr = GetInput(node, 0); + auto b_ptr = GetInput(node, 1); + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = data_entry_[output_eid]; + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, out_ptr, false, false, 1.0, 0.0); + } + } } } private: + const DLTensor* GetInput(const JSONGraphNode& node, const int idx) { + ICHECK_LT(idx, node.GetInputs().size()); + auto eid = EntryID(node.GetInputs()[idx]); + ICHECK(eid < data_entry_.size()); + return data_entry_[eid]; + } }; runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 62863b8f7bc8..3a72ba4328ba 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -104,6 +104,10 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) { } LOG(FATAL) << "Unsupported cuda type"; } + +void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* C, + bool transa, bool transb, float alpha, float beta); + } // namespace contrib } // namespace tvm diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index 0986d3caf29f..81369ec3533d 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -22,7 +22,7 @@ import tvm.topi.testing from tvm import relax from tvm.contrib.pickle_memoize import memoize -from tvm.relax.backend import get_patterns_with_prefix +from tvm.relax.backend.contrib.cublas import partition_for_cublas from tvm.script import relax as R @@ -54,12 +54,8 @@ def build_and_run(mod, inputs_np, target, legalize=False): def get_result_with_relax_cublas_offload(mod, *args): - patterns = [(entry.name, entry.pattern) for entry in get_patterns_with_prefix("cublas")] - assert len(patterns) != 0, "Cannot find cublas patterns" - - # TODO: partition - codegen_pass = relax.transform.RunCodegen() - mod = codegen_pass(mod) + mod = partition_for_cublas(mod) + 
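+    # RunCodegen hands each partitioned function to the compiler registered
+    # as "relax.ext.cublas" in codegen.cc.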
mod = relax.transform.RunCodegen()(mod) return build_and_run(mod, args, "cuda") @@ -122,9 +118,9 @@ def _to_concrete_shape(symbolic_shape, var_table): _epilogue_table = { "none": (False, None), - "bias": (True, None), - "relu": (True, R.nn.relu), - "gelu": (True, R.nn.gelu), + # "bias": (True, None), + # "relu": (True, R.nn.relu), + # "gelu": (True, R.nn.gelu), } @@ -132,37 +128,38 @@ def _to_concrete_shape(symbolic_shape, var_table): "x_shape, y_shape, transpose_y, epilogue", [ # Regular - ((32, 6), (6, 16), False, "none"), - ((_vars["a"], 6), (6, 16), False, "bias"), - # Transposed - ((4, 16), (16, 128), True, "relu"), - ((35, 8), (8, 8), True, "gelu"), - # 3D x 3D - ((6, 32, 8), (6, 8, 10), False, "bias"), - ((6, 32, 8), (6, 8, 10), True, "none"), - ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"), - # 3D x 2D - ((6, 32, 8), (8, 10), False, "none"), - ((_vars["a"], 32, 8), (8, 10), False, "bias"), - ((10, 16, 8), (8, 10), True, "relu"), - # 2D x 3D - ((32, 8), (10, 8, 10), False, "relu"), - ((32, 8), (_vars["a"], 8, 10), True, "gelu"), - # ND x 2D - ((3, 6, 32, 8), (8, 10), False, "bias"), - ((_vars["a"], _vars["b"], 6, 32, 8), (8, 10), False, "none"), - # 2D x ND - ((32, 8), (5, 3, 8, 10), False, "gelu"), - # ND x ND - ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"), - ((3, 2, 4, 16, 15), (1, 1, 15, 2), True, "gelu"), - ((1, 1, 16, 15), (3, 2, _vars["a"], 15, 2), False, "none"), + ((8, 8), (8, 8), False, "none"), + # ((_vars["a"], 6), (6, 16), False, "bias"), + # # Transposed + # ((4, 16), (16, 128), True, "relu"), + # ((35, 8), (8, 8), True, "gelu"), + # # 3D x 3D + # ((6, 32, 8), (6, 8, 10), False, "bias"), + # ((6, 32, 8), (6, 8, 10), True, "none"), + # ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"), + # # 3D x 2D + # ((6, 32, 8), (8, 10), False, "none"), + # ((_vars["a"], 32, 8), (8, 10), False, "bias"), + # ((10, 16, 8), (8, 10), True, "relu"), + # # 2D x 3D + # ((32, 8), (10, 8, 10), False, "relu"), + # ((32, 8), (_vars["a"], 8, 10), True, "gelu"), + # # ND x 2D + # ((3, 6, 32, 8), (8, 10), False, "bias"), + # ((_vars["a"], _vars["b"], 6, 32, 8), (8, 10), False, "none"), + # # 2D x ND + # ((32, 8), (5, 3, 8, 10), False, "gelu"), + # # ND x ND + # ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"), + # ((3, 2, 4, 16, 15), (1, 1, 15, 2), True, "gelu"), + # ((1, 1, 16, 15), (3, 2, _vars["a"], 15, 2), False, "none"), ], ) @pytest.mark.parametrize( "dtype", [ - "float16", + # "float16", + "float32", ], ) def test_matmul_offload( @@ -198,6 +195,7 @@ def test_matmul_offload( transposed_y=transpose_y, activation=activation, ) + out = get_result_with_relax_cublas_offload(mod, *args) ref = build_and_run(mod, args, "llvm", legalize=True) @@ -205,4 +203,5 @@ def test_matmul_offload( if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_matmul_offload((32, 8), (8, 16), False, "none", "float32") diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index c8ca44311de5..1c814294c94e 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -23,7 +23,6 @@ from tvm import relax from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul from tvm.contrib.pickle_memoize import memoize -from tvm.relax.backend import get_patterns_with_prefix from tvm.relax.backend.contrib.cutlass import partition_for_cutlass from tvm.script import relax as R from tvm.script.ir_builder import IRBuilder @@ -96,9 +95,6 @@ def build_and_run(mod, inputs_np, target, legalize=False): 
def get_result_with_relax_cutlass_offload(mod, *args, assert_all_bindings_fused=True): - patterns = [(entry.name, entry.pattern) for entry in get_patterns_with_prefix("cutlass")] - assert len(patterns) != 0, "Cannot find cutlass patterns" - mod = partition_for_cutlass(mod) if assert_all_bindings_fused: From 3ce49a68411e25aa9d8348b738cf32d4411ab0f5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Mar 2023 21:49:18 +0900 Subject: [PATCH 05/15] transposed gemm work --- src/runtime/contrib/cblas/gemm_common.h | 2 +- src/runtime/contrib/cublas/cublas.cc | 18 +++++++++--------- .../contrib/cublas/cublas_json_runtime.cc | 9 ++++++++- tests/python/relax/test_codegen_cublas.py | 1 + 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index 9946484ae297..fe05f5f483fc 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -35,7 +35,7 @@ namespace tvm { namespace contrib { using namespace runtime; -inline int ColumnStride(DLTensor* tensor) { +inline int ColumnStride(const DLTensor* tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides // (the other stride is 1) is the column stride. diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index aff2cfd135f6..ba5848532370 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -159,23 +159,23 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co cublasOperation_t opTransB = CUBLASBooleanToTranspose(transb); CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, - &opTransA, sizeof(opTransA))); + &opTransB, sizeof(opTransA))); CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, - &opTransB, sizeof(opTransB))); + &opTransA, sizeof(opTransB))); cublasLtMatrixLayout_t Adesc, Bdesc, Cdesc; cudaDataType_t ab_type = CUDA_R_32F; cudaDataType_t c_type = CUDA_R_32F; - // TODO - int lda = M; - int ldb = K; + int lda = opTransB == CUBLAS_OP_N? M : K; + int ldb = opTransA == CUBLAS_OP_N? K : N; int ldc = M; - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, ab_type, opTransA == CUBLAS_OP_N ? M : K, - opTransA == CUBLAS_OP_N ? K : M, lda)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, ab_type, opTransB == CUBLAS_OP_N ? K : N, - opTransB == CUBLAS_OP_N ? N : K, ldb)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, ab_type, opTransB == CUBLAS_OP_N ? M : K, + opTransB == CUBLAS_OP_N ? K : M, lda)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, ab_type, opTransA == CUBLAS_OP_N ? K : N, + opTransA == CUBLAS_OP_N ? 
N : K, ldb)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, c_type, M, N, ldc)); auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 98087022fe15..2e85d5989c8c 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -61,11 +61,18 @@ class CublasJSONRuntime : public JSONRuntimeBase { if (node.GetOpType() == "kernel") { auto op_name = node.GetOpName(); if (op_name == "cublas.matmul") { + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = data_entry_[output_eid]; auto a_ptr = GetInput(node, 0); auto b_ptr = GetInput(node, 1); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, out_ptr, false, false, 1.0, 0.0); + } else if (op_name == "cublas.matmul_transposed") { uint32_t output_eid = EntryID(outputs_[0]); auto out_ptr = data_entry_[output_eid]; - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, out_ptr, false, false, 1.0, 0.0); + // TODO: fix + auto a_ptr = GetInput(node, 1); + auto b_ptr = GetInput(node, 0); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, out_ptr, false, true, 1.0, 0.0); } } } diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index 81369ec3533d..427fc3bf4d1b 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -205,3 +205,4 @@ def test_matmul_offload( if __name__ == "__main__": # tvm.testing.main() test_matmul_offload((32, 8), (8, 16), False, "none", "float32") + test_matmul_offload((32, 8), (8, 16), True, "none", "float32") From 9f48c2b7383368415ebe089f2aaa8b6db9bd22b8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 11 Mar 2023 05:16:05 +0900 Subject: [PATCH 06/15] wip --- src/runtime/contrib/cublas/cublas.cc | 70 +++++++++++++++++++ .../contrib/cublas/cublas_json_runtime.cc | 14 ++++ src/runtime/contrib/cublas/cublas_utils.h | 3 + tests/python/relax/test_codegen_cublas.py | 6 +- 4 files changed, 91 insertions(+), 2 deletions(-) diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index ba5848532370..ab8a6d743297 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -186,6 +186,76 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co C_data, Cdesc, C_data, Cdesc, nullptr, nullptr, 0, nullptr)); } +void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, + bool transa, bool transb, cublasLtEpilogue_t epilouge) { + ICHECK_EQ(A->ndim, 2); + ICHECK_EQ(B->ndim, 2); + ICHECK_EQ(C->ndim, 2); + ICHECK_EQ(ElementStride(A), 1); + ICHECK_EQ(ElementStride(B), 1); + ICHECK_EQ(ElementStride(C), 1); + ICHECK(TypeEqual(A->dtype, B->dtype)); + + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed(A) ? !transa : transa; + transb = IsInPlaceTransposed(B) ? 
!transb : transb; + + int M = ColumnCount(B, transb); + int N = RowCount(A, transa); + int K = ColumnCount(A, transa); + + cublasLtMatmulDesc_t operationDesc = nullptr; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + + if (bias_ptr != nullptr) { + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, + &bias->data, + sizeof(float*))) + } + + if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) { + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, + sizeof(epilogue))) + } + + cublasOperation_t opTransA = CUBLASBooleanToTranspose(transa); + cublasOperation_t opTransB = CUBLASBooleanToTranspose(transb); + + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTransB, sizeof(opTransA))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTransA, sizeof(opTransB))); + + cublasLtMatrixLayout_t Adesc, Bdesc, Cdesc; + cudaDataType_t ab_type = CUDA_R_32F; + cudaDataType_t c_type = CUDA_R_32F; + + int lda = opTransB == CUBLAS_OP_N? M : K; + int ldb = opTransA == CUBLAS_OP_N? K : N; + int ldc = M; + + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, ab_type, opTransB == CUBLAS_OP_N ? M : K, + opTransB == CUBLAS_OP_N ? K : M, lda)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, ab_type, opTransA == CUBLAS_OP_N ? K : N, + opTransA == CUBLAS_OP_N ? N : K, ldb)); + + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, c_type, M, N, ldc)); + + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + float alpha = 1.0; + float beta = 0.0; + CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, operationDesc, &alpha, B_data, Adesc, A_data, Bdesc, &beta, + C_data, Cdesc, C_data, Cdesc, nullptr, nullptr, 0, nullptr)); +} + inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) { DLTensor* A = args[0]; DLTensor* B = args[1]; diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 2e85d5989c8c..ee3adbd8a3fd 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -73,6 +73,20 @@ class CublasJSONRuntime : public JSONRuntimeBase { auto a_ptr = GetInput(node, 1); auto b_ptr = GetInput(node, 0); tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, out_ptr, false, true, 1.0, 0.0); + } else if (op_name == "cublas.matmul_bias") { + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 0); + auto b_ptr = GetInput(node, 1); + auto bias_ptr = GetInput(node, 2); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, CUBLASLT_EPILOGUE_BIAS); + } else if (op_name == "cublas.matmul_bias_relu") { + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 0); + auto b_ptr = GetInput(node, 1); + auto bias_ptr = GetInput(node, 2); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, CUBLASLT_EPILOGUE_RELU_BIAS); } } } diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 3a72ba4328ba..858011f54149 100644 --- 
a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -108,6 +108,9 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) { void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* C, bool transa, bool transb, float alpha, float beta); +void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, + bool transa, bool transb, cublasLtEpilogue_t epilouge=CUBLASLT_EPILOGUE_DEFAULT); + } // namespace contrib } // namespace tvm diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index 427fc3bf4d1b..eed8084419f2 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -204,5 +204,7 @@ def test_matmul_offload( if __name__ == "__main__": # tvm.testing.main() - test_matmul_offload((32, 8), (8, 16), False, "none", "float32") - test_matmul_offload((32, 8), (8, 16), True, "none", "float32") + # test_matmul_offload((32, 8), (8, 16), False, "none", "float32") + # test_matmul_offload((32, 8), (8, 16), True, "none", "float32") + test_matmul_offload((32, 8), (8, 16), False, "bias", "float32") + test_matmul_offload((32, 8), (8, 16), False, "relu", "float32") From a834b0f4a18534624837a0c88ecbc4cdf5144a8f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 11 Mar 2023 05:33:15 +0900 Subject: [PATCH 07/15] bias and epilogue work --- src/runtime/contrib/cublas/cublas.cc | 8 ++++---- src/runtime/contrib/cublas/cublas_json_runtime.cc | 7 +++++++ src/runtime/contrib/cublas/cublas_utils.h | 2 +- tests/python/relax/test_codegen_cublas.py | 10 +++++----- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index ab8a6d743297..ecc8b35a89a4 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -187,7 +187,7 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co } void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, - bool transa, bool transb, cublasLtEpilogue_t epilouge) { + bool transa, bool transb, cublasLtEpilogue_t epilogue) { ICHECK_EQ(A->ndim, 2); ICHECK_EQ(B->ndim, 2); ICHECK_EQ(C->ndim, 2); @@ -207,12 +207,12 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co cublasLtMatmulDesc_t operationDesc = nullptr; CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); - if (bias_ptr != nullptr) { + if (bias != nullptr) { CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias->data, - sizeof(float*))) + sizeof(float*))); } if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) { @@ -220,7 +220,7 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, - sizeof(epilogue))) + sizeof(epilogue))); } cublasOperation_t opTransA = CUBLASBooleanToTranspose(transa); diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index ee3adbd8a3fd..55c5d564ae03 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -87,6 +87,13 @@ class CublasJSONRuntime : public JSONRuntimeBase { auto b_ptr = GetInput(node, 1); auto bias_ptr = GetInput(node, 2); 
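+          // CUBLASLT_EPILOGUE_RELU_BIAS applies the bias add and ReLU as part
+          // of the GEMM launch, so no separate elementwise kernel is needed.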
tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, CUBLASLT_EPILOGUE_RELU_BIAS); + } else if (op_name == "cublas.matmul_bias_gelu") { + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 0); + auto b_ptr = GetInput(node, 1); + auto bias_ptr = GetInput(node, 2); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, CUBLASLT_EPILOGUE_GELU_BIAS); } } } diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 858011f54149..6aa7b9e92df7 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -109,7 +109,7 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co bool transa, bool transb, float alpha, float beta); void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, - bool transa, bool transb, cublasLtEpilogue_t epilouge=CUBLASLT_EPILOGUE_DEFAULT); + bool transa, bool transb, cublasLtEpilogue_t epilogue=CUBLASLT_EPILOGUE_DEFAULT); } // namespace contrib } // namespace tvm diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index eed8084419f2..0cdec6de3f71 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -118,9 +118,9 @@ def _to_concrete_shape(symbolic_shape, var_table): _epilogue_table = { "none": (False, None), - # "bias": (True, None), - # "relu": (True, R.nn.relu), - # "gelu": (True, R.nn.gelu), + "bias": (True, None), + "relu": (True, R.nn.relu), + "gelu": (True, R.nn.gelu), } @@ -206,5 +206,5 @@ def test_matmul_offload( # tvm.testing.main() # test_matmul_offload((32, 8), (8, 16), False, "none", "float32") # test_matmul_offload((32, 8), (8, 16), True, "none", "float32") - test_matmul_offload((32, 8), (8, 16), False, "bias", "float32") - test_matmul_offload((32, 8), (8, 16), False, "relu", "float32") + # test_matmul_offload((32, 8), (8, 16), False, "bias", "float32") + test_matmul_offload((32, 8), (8, 16), False, "gelu", "float32") From 5e7480ea65600007debae93b628a60478ada1710 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 13 Mar 2023 20:50:54 +0900 Subject: [PATCH 08/15] support fp16 and transposed bias --- src/runtime/contrib/cublas/cublas.cc | 87 ++++++------------- .../contrib/cublas/cublas_json_runtime.cc | 27 +++++- src/runtime/contrib/cublas/cublas_utils.h | 3 - tests/python/relax/test_codegen_cublas.py | 18 ++-- 4 files changed, 63 insertions(+), 72 deletions(-) diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index ecc8b35a89a4..cf6d3f8aa13b 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -27,6 +27,8 @@ #include "../cblas/gemm_common.h" #include "cublas_utils.h" +#include "../../3rdparty/compiler-rt/builtin_fp16.h" + namespace tvm { namespace contrib { @@ -134,8 +136,8 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; } #if CUDART_VERSION >= 10010 -void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* C, - bool transa, bool transb, float alpha, float beta) { +void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, + bool transa, bool transb, cublasLtEpilogue_t epilogue) { ICHECK_EQ(A->ndim, 2); ICHECK_EQ(B->ndim, 2); ICHECK_EQ(C->ndim, 2); @@ -153,59 +155,30 @@ 
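+// Note: cuBLAS is column-major while TVM tensors are row-major, so a
+// row-major C = A * B is computed as the column-major product
+// C^T = B^T * A^T. That is why M/N are taken from B/A respectively, the
+// TRANSA/TRANSB attributes are swapped, and B is passed first to
+// cublasLtMatmul below. `bias` may be nullptr and `epilogue` may be
+// CUBLASLT_EPILOGUE_DEFAULT for a plain GEMM.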
void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co int K = ColumnCount(A, transa); cublasLtMatmulDesc_t operationDesc = nullptr; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); - - cublasOperation_t opTransA = CUBLASBooleanToTranspose(transa); - cublasOperation_t opTransB = CUBLASBooleanToTranspose(transb); - - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, - &opTransB, sizeof(opTransA))); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, - &opTransA, sizeof(opTransB))); - - cublasLtMatrixLayout_t Adesc, Bdesc, Cdesc; + auto compute_type = CUBLAS_COMPUTE_32F; + auto scale_type = CUDA_R_32F; cudaDataType_t ab_type = CUDA_R_32F; cudaDataType_t c_type = CUDA_R_32F; + float one_fp32 = 1.0; + float zero_fp32 = 0.0; + auto one_fp16 = __truncXfYf2__(1.0); + auto zero_fp16 = __truncXfYf2__(0.0); + void* alpha = &one_fp32; + void* beta = &zero_fp32; + + if (A->dtype.bits == 16 && A->dtype.code == kDLFloat) { + ab_type = CUDA_R_16F; + } - int lda = opTransB == CUBLAS_OP_N? M : K; - int ldb = opTransA == CUBLAS_OP_N? K : N; - int ldc = M; - - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, ab_type, opTransB == CUBLAS_OP_N ? M : K, - opTransB == CUBLAS_OP_N ? K : M, lda)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, ab_type, opTransA == CUBLAS_OP_N ? K : N, - opTransA == CUBLAS_OP_N ? N : K, ldb)); - - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, c_type, M, N, ldc)); - - auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); - auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); - auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); - - CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, operationDesc, &alpha, B_data, Adesc, A_data, Bdesc, &beta, - C_data, Cdesc, C_data, Cdesc, nullptr, nullptr, 0, nullptr)); -} - -void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, - bool transa, bool transb, cublasLtEpilogue_t epilogue) { - ICHECK_EQ(A->ndim, 2); - ICHECK_EQ(B->ndim, 2); - ICHECK_EQ(C->ndim, 2); - ICHECK_EQ(ElementStride(A), 1); - ICHECK_EQ(ElementStride(B), 1); - ICHECK_EQ(ElementStride(C), 1); - ICHECK(TypeEqual(A->dtype, B->dtype)); - - // Reversed strides indicates an in-place transpose operation. - transa = IsInPlaceTransposed(A) ? !transa : transa; - transb = IsInPlaceTransposed(B) ? !transb : transb; - - int M = ColumnCount(B, transb); - int N = RowCount(A, transa); - int K = ColumnCount(A, transa); + if (C->dtype.bits == 16 && C->dtype.code == kDLFloat) { + c_type = CUDA_R_16F; + compute_type = CUBLAS_COMPUTE_16F; + scale_type = CUDA_R_16F; + alpha = &one_fp16; + beta = &zero_fp16; + } - cublasLtMatmulDesc_t operationDesc = nullptr; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, compute_type, scale_type)); if (bias != nullptr) { CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( @@ -232,8 +205,6 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co &opTransA, sizeof(opTransB))); cublasLtMatrixLayout_t Adesc, Bdesc, Cdesc; - cudaDataType_t ab_type = CUDA_R_32F; - cudaDataType_t c_type = CUDA_R_32F; int lda = opTransB == CUBLAS_OP_N? M : K; int ldb = opTransA == CUBLAS_OP_N? 
K : N; @@ -246,13 +217,11 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, c_type, M, N, ldc)); - auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); - auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); - auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + auto A_data = static_cast(A->data) + A->byte_offset; + auto B_data = static_cast(B->data) + B->byte_offset; + auto C_data = static_cast(C->data) + C->byte_offset; - float alpha = 1.0; - float beta = 0.0; - CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, operationDesc, &alpha, B_data, Adesc, A_data, Bdesc, &beta, + CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, operationDesc, alpha, B_data, Adesc, A_data, Bdesc, beta, C_data, Cdesc, C_data, Cdesc, nullptr, nullptr, 0, nullptr)); } diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 55c5d564ae03..c6cee4f801c2 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -65,14 +65,14 @@ class CublasJSONRuntime : public JSONRuntimeBase { auto out_ptr = data_entry_[output_eid]; auto a_ptr = GetInput(node, 0); auto b_ptr = GetInput(node, 1); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, out_ptr, false, false, 1.0, 0.0); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, nullptr, out_ptr, false, false, CUBLASLT_EPILOGUE_DEFAULT); } else if (op_name == "cublas.matmul_transposed") { uint32_t output_eid = EntryID(outputs_[0]); auto out_ptr = data_entry_[output_eid]; // TODO: fix auto a_ptr = GetInput(node, 1); auto b_ptr = GetInput(node, 0); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, out_ptr, false, true, 1.0, 0.0); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, nullptr, out_ptr, false, true, CUBLASLT_EPILOGUE_DEFAULT); } else if (op_name == "cublas.matmul_bias") { uint32_t output_eid = EntryID(outputs_[0]); auto out_ptr = data_entry_[output_eid]; @@ -94,6 +94,29 @@ class CublasJSONRuntime : public JSONRuntimeBase { auto b_ptr = GetInput(node, 1); auto bias_ptr = GetInput(node, 2); tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, CUBLASLT_EPILOGUE_GELU_BIAS); + } else if (op_name == "cublas.matmul_transposed_bias") { + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 1); + auto b_ptr = GetInput(node, 0); + auto bias_ptr = GetInput(node, 2); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, CUBLASLT_EPILOGUE_BIAS); + } else if (op_name == "cublas.matmul_transposed_bias_relu") { + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 1); + auto b_ptr = GetInput(node, 0); + auto bias_ptr = GetInput(node, 2); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, CUBLASLT_EPILOGUE_RELU_BIAS); + } else if (op_name == "cublas.matmul_transposed_bias_gelu") { + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 1); + auto b_ptr = GetInput(node, 0); + auto bias_ptr = GetInput(node, 2); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, CUBLASLT_EPILOGUE_GELU_BIAS); + } else { + LOG(FATAL) << op_name; } } } diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 
6aa7b9e92df7..1cd327f3b908 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -105,9 +105,6 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) { LOG(FATAL) << "Unsupported cuda type"; } -void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* C, - bool transa, bool transb, float alpha, float beta); - void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa, bool transb, cublasLtEpilogue_t epilogue=CUBLASLT_EPILOGUE_DEFAULT); diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index 0cdec6de3f71..abc78dff0052 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -129,10 +129,10 @@ def _to_concrete_shape(symbolic_shape, var_table): [ # Regular ((8, 8), (8, 8), False, "none"), - # ((_vars["a"], 6), (6, 16), False, "bias"), - # # Transposed - # ((4, 16), (16, 128), True, "relu"), - # ((35, 8), (8, 8), True, "gelu"), + ((_vars["a"], 6), (6, 16), False, "bias"), + # Transposed + ((4, 16), (16, 128), True, "relu"), + ((35, 8), (8, 8), True, "gelu"), # # 3D x 3D # ((6, 32, 8), (6, 8, 10), False, "bias"), # ((6, 32, 8), (6, 8, 10), True, "none"), @@ -158,7 +158,7 @@ def _to_concrete_shape(symbolic_shape, var_table): @pytest.mark.parametrize( "dtype", [ - # "float16", + "float16", "float32", ], ) @@ -200,11 +200,13 @@ def test_matmul_offload( ref = build_and_run(mod, args, "llvm", legalize=True) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + print("ok") if __name__ == "__main__": # tvm.testing.main() - # test_matmul_offload((32, 8), (8, 16), False, "none", "float32") - # test_matmul_offload((32, 8), (8, 16), True, "none", "float32") + test_matmul_offload((32, 8), (8, 16), False, "none", "float16") + # test_matmul_offload((32, 8), (8, 16), True, "relu", "float32") # test_matmul_offload((32, 8), (8, 16), False, "bias", "float32") - test_matmul_offload((32, 8), (8, 16), False, "gelu", "float32") + # test_matmul_offload((32, 8), (8, 16), False, "gelu", "float32") + # test_matmul_offload((_vars["a"], 8), (8, 16), False, "relu", "float32") From f9ce24e7789c16b376b8793d1db363813be45125 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 13 Mar 2023 21:07:32 +0900 Subject: [PATCH 09/15] support batched gemm --- python/tvm/relax/backend/contrib/cublas.py | 59 ++++++++-- src/runtime/contrib/cblas/gemm_common.h | 8 +- src/runtime/contrib/cublas/cublas.cc | 99 ++++++++++------- .../contrib/cublas/cublas_json_runtime.cc | 105 +++++++++--------- src/runtime/contrib/cublas/cublas_utils.h | 5 +- tests/python/relax/test_codegen_cublas.py | 33 +----- 6 files changed, 177 insertions(+), 132 deletions(-) diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py index 96c7b48b9283..3df205dccbd0 100644 --- a/python/tvm/relax/backend/contrib/cublas.py +++ b/python/tvm/relax/backend/contrib/cublas.py @@ -15,12 +15,14 @@ # specific language governing permissions and limitations # under the License. 
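The _check_matmul rewrite in the hunks below boils down to flattening all leading dimensions of each operand into a single batch count and requiring the two sides to agree. A minimal standalone sketch of that rule, using plain Python integers rather than Relax struct info (the real check also verifies dtypes and lets symbolic tvm.tir.Var batch dimensions through; batch_counts_compatible is an illustrative name, not part of the patch):

import operator
from functools import reduce


def batch_counts_compatible(lhs_shape, rhs_shape):
    # Flatten every leading (batch) dimension into a single count; the
    # trailing two dimensions are the matrix proper.
    lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
    rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
    # Plain GEMM is always fine; for batched GEMM the flattened batch
    # counts must match, since cuBLASLt cannot broadcast one operand over
    # the batch dimension (no batch_stride = 0 support observed).
    return (lhs_batches == 1 and rhs_batches == 1) or lhs_batches == rhs_batches


assert batch_counts_compatible((32, 8), (8, 16))            # 2D x 2D
assert batch_counts_compatible((6, 32, 8), (6, 8, 10))      # 3D x 3D, equal batches
assert not batch_counts_compatible((6, 32, 8), (2, 8, 10))  # mismatched batch counts
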
-"""Pattern table for CUTLASS backend""" +"""Pattern table for cuBLAS backend""" +import operator +from functools import reduce -from typing import Mapping, Optional, Tuple +from typing import Mapping import tvm -from tvm.relax import Call, Expr, ShapeExpr, transform +from tvm.relax import Call, Expr, transform from tvm.relax.dpl import CallPattern, DFPattern from ..pattern_registry import get_patterns_with_prefix, register_patterns @@ -28,11 +30,9 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype): - """Check if dtypes in the given workload are supported by CUTLASS.""" - return ( - (lhs_dtype == "float16" and rhs_dtype == "float16") - or (lhs_dtype == "float32" and rhs_dtype == "float32") - or (lhs_dtype in ("int8", "uint8") and rhs_dtype in ("int8", "uint8")) + """Check if dtypes in the given workload are supported by cuBLAS BYOC.""" + return (lhs_dtype == "float16" and rhs_dtype == "float16") or ( + lhs_dtype == "float32" and rhs_dtype == "float32" ) @@ -40,7 +40,44 @@ def _check_matmul( match_result: Mapping[DFPattern, Expr], _: Expr, ) -> bool: - return True + matmul_call: Call = None + for pattern, expr in match_result.items(): + if ( + isinstance(expr, Call) + and isinstance(pattern, CallPattern) + and isinstance(expr.op, tvm.ir.Op) + and expr.op.name == "relax.matmul" + ): + matmul_call = expr + if matmul_call is None: + raise ValueError("Cannot find call to matmul from match_result.") + + lhs, rhs, *_ = matmul_call.args + + lhs_dtype = lhs.struct_info.dtype + rhs_dtype = rhs.struct_info.dtype + if not _is_supported_dtype(lhs_dtype, rhs_dtype): + return False + + lhs_shape = lhs.struct_info.shape.values + rhs_shape = rhs.struct_info.shape.values + + if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)): + # Reduction axis must be constant + return False + + lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1) + rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1) + + # CublasLT does not seem to support batched GEMM with one of matrices having + # one batch (with batch_stride 0). So for batched GEMM, the two batch counts + # must be equal. + return ( + (lhs_batches == 1 and rhs_batches == 1) + or isinstance(lhs_batches, tvm.tir.Var) + or isinstance(rhs_batches, tvm.tir.Var) + or (int(lhs_batches) == int(rhs_batches)) + ) register_patterns( @@ -115,7 +152,7 @@ def _check_matmul( def partition_for_cublas(mod): """ - Partition the input module into CUTLASS-supported subgraphs. + Partition the input module into cuBLAS-supported subgraphs. Parameters ---------- @@ -126,7 +163,7 @@ def partition_for_cublas(mod): ------- mod: tvm.IRModule The resulting IRModule, containing partitioned subgraphs to be - compiled by the CUTLASS backend. + offloaded to the cuBLAS backend. """ cublas_pattern_entries = get_patterns_with_prefix("cublas") diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index fe05f5f483fc..af073da9ba1a 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -59,9 +59,13 @@ inline bool IsInPlaceTransposed(const DLTensor* tensor) { return tensor->strides && (tensor->strides[1] > tensor->strides[0]); } -inline int RowCount(const DLTensor* tensor, bool trans) { return tensor->shape[trans ? 1 : 0]; } +inline int RowCount(const DLTensor* tensor, bool trans, int batch_offset = 0) { + return tensor->shape[batch_offset + (trans ? 1 : 0)]; +} -inline int ColumnCount(const DLTensor* tensor, bool trans) { return tensor->shape[trans ? 
0 : 1]; } +inline int ColumnCount(const DLTensor* tensor, bool trans, int batch_offset = 0) { + return tensor->shape[batch_offset + (trans ? 0 : 1)]; +} // Call a column major blas. Note that data is stored in tvm as row // major, so this we switch the arguments. diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index cf6d3f8aa13b..b53e99660c5a 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -24,11 +24,10 @@ #include #include +#include "../../3rdparty/compiler-rt/builtin_fp16.h" #include "../cblas/gemm_common.h" #include "cublas_utils.h" -#include "../../3rdparty/compiler-rt/builtin_fp16.h" - namespace tvm { namespace contrib { @@ -136,25 +135,15 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; } #if CUDART_VERSION >= 10010 -void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, - bool transa, bool transb, cublasLtEpilogue_t epilogue) { - ICHECK_EQ(A->ndim, 2); - ICHECK_EQ(B->ndim, 2); - ICHECK_EQ(C->ndim, 2); - ICHECK_EQ(ElementStride(A), 1); - ICHECK_EQ(ElementStride(B), 1); - ICHECK_EQ(ElementStride(C), 1); +void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, + const DLTensor* C, bool transa, bool transb, cublasLtEpilogue_t epilogue) { ICHECK(TypeEqual(A->dtype, B->dtype)); - // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? !transb : transb; - int M = ColumnCount(B, transb); - int N = RowCount(A, transa); - int K = ColumnCount(A, transa); - - cublasLtMatmulDesc_t operationDesc = nullptr; + cublasOperation_t opTransA = CUBLASBooleanToTranspose(transa); + cublasOperation_t opTransB = CUBLASBooleanToTranspose(transb); auto compute_type = CUBLAS_COMPUTE_32F; auto scale_type = CUDA_R_32F; cudaDataType_t ab_type = CUDA_R_32F; @@ -178,45 +167,75 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co beta = &zero_fp16; } + cublasLtMatmulDesc_t operationDesc = nullptr; CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, compute_type, scale_type)); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTransB, sizeof(opTransA))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTransA, sizeof(opTransB))); if (bias != nullptr) { CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &bias->data, - sizeof(float*))); + operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias->data, sizeof(float*))); } if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) { - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epilogue, - sizeof(epilogue))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, sizeof(epilogue))); } - cublasOperation_t opTransA = CUBLASBooleanToTranspose(transa); - cublasOperation_t opTransB = CUBLASBooleanToTranspose(transb); - - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, - &opTransB, sizeof(opTransA))); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, - &opTransA, sizeof(opTransB))); + cublasLtMatrixLayout_t Adesc = nullptr; + cublasLtMatrixLayout_t Bdesc = nullptr; + cublasLtMatrixLayout_t Cdesc = 
nullptr; - cublasLtMatrixLayout_t Adesc, Bdesc, Cdesc; + auto set_batch = [](cublasLtMatrixLayout_t mat_desc, int batch_count, int64_t batch_stride) { + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + mat_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(mat_desc, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_stride, sizeof(batch_stride))); + }; - int lda = opTransB == CUBLAS_OP_N? M : K; - int ldb = opTransA == CUBLAS_OP_N? K : N; - int ldc = M; + int batch_offset_A = A->ndim - 2; + int batch_offset_B = B->ndim - 2; - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, ab_type, opTransB == CUBLAS_OP_N ? M : K, - opTransB == CUBLAS_OP_N ? K : M, lda)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, ab_type, opTransA == CUBLAS_OP_N ? K : N, - opTransA == CUBLAS_OP_N ? N : K, ldb)); + int M = ColumnCount(B, transb, batch_offset_B); + int N = RowCount(A, transa, batch_offset_A); + int K = ColumnCount(A, transa, batch_offset_A); + int lda = transb ? K : M; + int ldb = transa ? N : K; + int ldc = M; + CHECK_CUBLAS_ERROR( + cublasLtMatrixLayoutCreate(&Adesc, ab_type, !transb ? M : K, !transb ? K : M, lda)); + CHECK_CUBLAS_ERROR( + cublasLtMatrixLayoutCreate(&Bdesc, ab_type, !transa ? K : N, !transa ? N : K, ldb)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, c_type, M, N, ldc)); + if (A->ndim != 2 || B->ndim != 2) { + auto get_batch_count = [](int64_t* shape, int batch_offset) { + int64_t count = 1; + for (int i = 0; i < batch_offset; ++i) { + count *= shape[i]; + } + return count; + }; + int batch_count_A = get_batch_count(A->shape, batch_offset_A); + int batch_count_B = get_batch_count(B->shape, batch_offset_B); + int batch_count_C = get_batch_count(C->shape, C->ndim - 2); + int64_t batch_stride_A = M * K; + int64_t batch_stride_B = K * N; + int64_t batch_stride_C = M * N; + + // CublasLT does not seem to support batched GEMM with one of matrices having + // one batch (with batch_stride 0). 
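The batch bookkeeping introduced just above is easier to follow outside the cuBLASLt calls. A rough Python rendering of the same arithmetic, assuming densely packed row-major tensors (the identifiers below are illustrative, not the C++ ones; the hunk continues right after this aside):

def batch_count(shape):
    # Product of all leading dimensions; the last two are (rows, cols).
    count = 1
    for dim in shape[:-2]:
        count *= dim
    return count


# For C = A @ B with A: (b, n, k), B: (b, k, m), C: (b, n, m), each batch
# element is one contiguous matrix, so the element stride between
# consecutive batch elements is simply the size of one matrix.
b, n, k, m = 6, 32, 8, 10
stride_a, stride_b, stride_c = n * k, k * m, n * m
assert batch_count((b, n, k)) == batch_count((b, k, m)) == batch_count((b, n, m)) == b
assert (stride_a, stride_b, stride_c) == (256, 80, 320)
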
+ ICHECK_EQ(batch_count_A, batch_count_B); + + set_batch(Adesc, batch_count_A, batch_stride_A); + set_batch(Bdesc, batch_count_B, batch_stride_B); + set_batch(Cdesc, batch_count_C, batch_stride_C); + } + auto A_data = static_cast(A->data) + A->byte_offset; auto B_data = static_cast(B->data) + B->byte_offset; auto C_data = static_cast(C->data) + C->byte_offset; diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index c6cee4f801c2..306d8161a59e 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -32,7 +32,6 @@ #include "../json/json_node.h" #include "../json/json_runtime.h" - #include "cublas_utils.h" namespace tvm { @@ -46,78 +45,84 @@ class CublasJSONRuntime : public JSONRuntimeBase { public: CublasJSONRuntime(const std::string& symbol_name, const std::string& graph_json, const Array const_names) - : JSONRuntimeBase(symbol_name, graph_json, const_names) { - } + : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - void Init(const Array& consts) override { - } + void Init(const Array& consts) override {} - void Run() override{ + void Run() override { cublasLtHandle_t handle; cublasLtCreate(&handle); - for (size_t i = 0; i < nodes_.size(); ++i) { + for (size_t i = 0; i < nodes_.size(); ++i) { const auto& node = nodes_[i]; if (node.GetOpType() == "kernel") { auto op_name = node.GetOpName(); if (op_name == "cublas.matmul") { uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 0); - auto b_ptr = GetInput(node, 1); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, nullptr, out_ptr, false, false, CUBLASLT_EPILOGUE_DEFAULT); - } else if (op_name == "cublas.matmul_transposed") { + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 0); + auto b_ptr = GetInput(node, 1); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, nullptr, out_ptr, false, false, + CUBLASLT_EPILOGUE_DEFAULT); + } else if (op_name == "cublas.matmul_transposed") { uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - // TODO: fix - auto a_ptr = GetInput(node, 1); - auto b_ptr = GetInput(node, 0); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, nullptr, out_ptr, false, true, CUBLASLT_EPILOGUE_DEFAULT); - } else if (op_name == "cublas.matmul_bias") { + auto out_ptr = data_entry_[output_eid]; + // TODO: fix + auto a_ptr = GetInput(node, 1); + auto b_ptr = GetInput(node, 0); + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, nullptr, out_ptr, false, true, + CUBLASLT_EPILOGUE_DEFAULT); + } else if (op_name == "cublas.matmul_bias") { uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 0); - auto b_ptr = GetInput(node, 1); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 0); + auto b_ptr = GetInput(node, 1); auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, CUBLASLT_EPILOGUE_BIAS); - } else if (op_name == "cublas.matmul_bias_relu") { + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, + CUBLASLT_EPILOGUE_BIAS); + } else if (op_name == "cublas.matmul_bias_relu") { uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 0); - auto b_ptr = GetInput(node, 1); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 0); + auto 
b_ptr = GetInput(node, 1); auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, CUBLASLT_EPILOGUE_RELU_BIAS); - } else if (op_name == "cublas.matmul_bias_gelu") { + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, + CUBLASLT_EPILOGUE_RELU_BIAS); + } else if (op_name == "cublas.matmul_bias_gelu") { uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 0); - auto b_ptr = GetInput(node, 1); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 0); + auto b_ptr = GetInput(node, 1); auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, CUBLASLT_EPILOGUE_GELU_BIAS); - } else if (op_name == "cublas.matmul_transposed_bias") { + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, + CUBLASLT_EPILOGUE_GELU_BIAS); + } else if (op_name == "cublas.matmul_transposed_bias") { uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 1); - auto b_ptr = GetInput(node, 0); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 1); + auto b_ptr = GetInput(node, 0); auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, CUBLASLT_EPILOGUE_BIAS); - } else if (op_name == "cublas.matmul_transposed_bias_relu") { + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, + CUBLASLT_EPILOGUE_BIAS); + } else if (op_name == "cublas.matmul_transposed_bias_relu") { uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 1); - auto b_ptr = GetInput(node, 0); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 1); + auto b_ptr = GetInput(node, 0); auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, CUBLASLT_EPILOGUE_RELU_BIAS); - } else if (op_name == "cublas.matmul_transposed_bias_gelu") { + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, + CUBLASLT_EPILOGUE_RELU_BIAS); + } else if (op_name == "cublas.matmul_transposed_bias_gelu") { uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 1); - auto b_ptr = GetInput(node, 0); + auto out_ptr = data_entry_[output_eid]; + auto a_ptr = GetInput(node, 1); + auto b_ptr = GetInput(node, 0); auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, CUBLASLT_EPILOGUE_GELU_BIAS); - } else { - LOG(FATAL) << op_name; - } + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, + CUBLASLT_EPILOGUE_GELU_BIAS); + } else { + LOG(FATAL) << op_name; + } } } } diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 1cd327f3b908..ae730b67a945 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -105,8 +105,9 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) { LOG(FATAL) << "Unsupported cuda type"; } -void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, - bool transa, bool transb, cublasLtEpilogue_t epilogue=CUBLASLT_EPILOGUE_DEFAULT); +void CallCublasLt(cublasLtHandle_t 
hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, + const DLTensor* C, bool transa, bool transb, + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT); } // namespace contrib } // namespace tvm diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index abc78dff0052..67484242a778 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -21,7 +21,6 @@ import tvm.testing import tvm.topi.testing from tvm import relax -from tvm.contrib.pickle_memoize import memoize from tvm.relax.backend.contrib.cublas import partition_for_cublas from tvm.script import relax as R @@ -134,25 +133,11 @@ def _to_concrete_shape(symbolic_shape, var_table): ((4, 16), (16, 128), True, "relu"), ((35, 8), (8, 8), True, "gelu"), # # 3D x 3D - # ((6, 32, 8), (6, 8, 10), False, "bias"), - # ((6, 32, 8), (6, 8, 10), True, "none"), - # ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"), - # # 3D x 2D - # ((6, 32, 8), (8, 10), False, "none"), - # ((_vars["a"], 32, 8), (8, 10), False, "bias"), - # ((10, 16, 8), (8, 10), True, "relu"), - # # 2D x 3D - # ((32, 8), (10, 8, 10), False, "relu"), - # ((32, 8), (_vars["a"], 8, 10), True, "gelu"), - # # ND x 2D - # ((3, 6, 32, 8), (8, 10), False, "bias"), - # ((_vars["a"], _vars["b"], 6, 32, 8), (8, 10), False, "none"), - # # 2D x ND - # ((32, 8), (5, 3, 8, 10), False, "gelu"), - # # ND x ND - # ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"), - # ((3, 2, 4, 16, 15), (1, 1, 15, 2), True, "gelu"), - # ((1, 1, 16, 15), (3, 2, _vars["a"], 15, 2), False, "none"), + ((6, 32, 8), (6, 8, 10), False, "bias"), + ((6, 32, 8), (6, 8, 10), True, "none"), + ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"), + # ND x ND + ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"), ], ) @pytest.mark.parametrize( @@ -200,13 +185,7 @@ def test_matmul_offload( ref = build_and_run(mod, args, "llvm", legalize=True) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) - print("ok") if __name__ == "__main__": - # tvm.testing.main() - test_matmul_offload((32, 8), (8, 16), False, "none", "float16") - # test_matmul_offload((32, 8), (8, 16), True, "relu", "float32") - # test_matmul_offload((32, 8), (8, 16), False, "bias", "float32") - # test_matmul_offload((32, 8), (8, 16), False, "gelu", "float32") - # test_matmul_offload((_vars["a"], 8), (8, 16), False, "relu", "float32") + tvm.testing.main() From bc0b2c57b116479d8711f8ad0508321fc6946cf0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 14 Mar 2023 10:42:17 +0900 Subject: [PATCH 10/15] clean up --- python/tvm/relax/backend/contrib/cublas.py | 2 +- src/relax/backend/contrib/cublas/codegen.cc | 32 +++++-- src/runtime/contrib/cublas/cublas.cc | 68 ++++++------- .../contrib/cublas/cublas_json_runtime.cc | 96 ++++++------------- src/runtime/contrib/cublas/cublas_utils.h | 1 + 5 files changed, 92 insertions(+), 107 deletions(-) diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py index 3df205dccbd0..d7e32f241dd2 100644 --- a/python/tvm/relax/backend/contrib/cublas.py +++ b/python/tvm/relax/backend/contrib/cublas.py @@ -69,7 +69,7 @@ def _check_matmul( lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1) rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1) - # CublasLT does not seem to support batched GEMM with one of matrices having + # cuBLASLt does not seem to support batched GEMM with one of the matrices having # one batch (with batch_stride 0).
So for batched GEMM, the two batch counts # must be equal. return ( diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 8fb23f29c0ff..74508ca002b8 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -22,7 +22,6 @@ * \brief Implementation of the CUBLAS JSON serializer. */ #include - #include #include "../codegen_json/codegen_json.h" @@ -37,6 +36,15 @@ using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; using JSONSerializer = backend::contrib::JSONSerializer; using backend::contrib::NodeEntries; + +Map ExtractArgIdx(String pattern_name, Function f) { + Map arg_idx; + arg_idx.Set("lhs", IntImm(DataType::Int(64), 1)); + arg_idx.Set("rhs", IntImm(DataType::Int(64), 0)); + arg_idx.Set("bias", IntImm(DataType::Int(64), 2)); + return arg_idx; +} + class CublasJSONSerializer : public JSONSerializer { public: CublasJSONSerializer(Map constant_names, Map bindings) @@ -55,21 +63,27 @@ class CublasJSONSerializer : public JSONSerializer { std::string composite_name = composite_opt.value(); - NodeEntries inputs; + NodeEntries inputs_tmp; for (const auto& arg : call_node->args) { auto res = VisitExpr(arg); - inputs.insert(inputs.end(), res.begin(), res.end()); + inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end()); } + + ICHECK(inputs_tmp.size() <= 3); + NodeEntries inputs(inputs_tmp.size()); + + auto arg_idx = ExtractArgIdx(composite_name, fn); + inputs[0] = inputs_tmp[arg_idx["lhs"]->value]; + inputs[1] = inputs_tmp[arg_idx["rhs"]->value]; + if (inputs_tmp.size() == 3) { + inputs[2] = inputs_tmp[arg_idx["bias"]->value]; + } + auto node = std::make_shared(composite_name, /* name_ */ "kernel", /* op_type_ */ inputs, 1 /* num_outputs_ */); - const CallNode* root_call = nullptr; - if (composite_name.find("matmul") != std::string::npos) { - root_call = backend::GetOpInFunction(fn, "relax.matmul"); - } else { - LOG(FATAL) << "Unimplemented pattern: " << composite_name; - } + const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul");; SetCallNodeAttribute(node, root_call); return AddNode(node, GetRef(call_node)); diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index b53e99660c5a..b49f15008cff 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -142,8 +142,6 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? 
!transb : transb; - cublasOperation_t opTransA = CUBLASBooleanToTranspose(transa); - cublasOperation_t opTransB = CUBLASBooleanToTranspose(transb); auto compute_type = CUBLAS_COMPUTE_32F; auto scale_type = CUDA_R_32F; cudaDataType_t ab_type = CUDA_R_32F; @@ -167,35 +165,26 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co beta = &zero_fp16; } - cublasLtMatmulDesc_t operationDesc = nullptr; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, compute_type, scale_type)); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, - &opTransB, sizeof(opTransA))); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, - &opTransA, sizeof(opTransB))); + cublasLtMatmulDesc_t op_desc; + cublasOperation_t op_transa = CUBLASBooleanToTranspose(transa); + cublasOperation_t op_transb = CUBLASBooleanToTranspose(transb); + + CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type)); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSA, + &op_transb, sizeof(op_transa))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSB, + &op_transa, sizeof(op_transb))); if (bias != nullptr) { - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias->data, sizeof(float*))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, + &bias->data, sizeof(float*))); } if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) { - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); } - cublasLtMatrixLayout_t Adesc = nullptr; - cublasLtMatrixLayout_t Bdesc = nullptr; - cublasLtMatrixLayout_t Cdesc = nullptr; - - auto set_batch = [](cublasLtMatrixLayout_t mat_desc, int batch_count, int64_t batch_stride) { - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - mat_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count))); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(mat_desc, - CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, - &batch_stride, sizeof(batch_stride))); - }; - int batch_offset_A = A->ndim - 2; int batch_offset_B = B->ndim - 2; @@ -206,11 +195,13 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co int lda = transb ? K : M; int ldb = transa ? N : K; int ldc = M; + + cublasLtMatrixLayout_t A_desc, B_desc, C_desc; CHECK_CUBLAS_ERROR( - cublasLtMatrixLayoutCreate(&Adesc, ab_type, !transb ? M : K, !transb ? K : M, lda)); + cublasLtMatrixLayoutCreate(&A_desc, ab_type, !transb ? M : K, !transb ? K : M, lda)); CHECK_CUBLAS_ERROR( - cublasLtMatrixLayoutCreate(&Bdesc, ab_type, !transa ? K : N, !transa ? N : K, ldb)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, c_type, M, N, ldc)); + cublasLtMatrixLayoutCreate(&B_desc, ab_type, !transa ? K : N, !transa ? 
N : K, ldb)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&C_desc, c_type, M, N, ldc)); if (A->ndim != 2 || B->ndim != 2) { auto get_batch_count = [](int64_t* shape, int batch_offset) { @@ -220,6 +211,14 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co } return count; }; + auto set_batch = [](cublasLtMatrixLayout_t mat_desc, int batch_count, int64_t batch_stride) { + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + mat_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count))); + CHECK_CUBLAS_ERROR( + cublasLtMatrixLayoutSetAttribute(mat_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_stride, sizeof(batch_stride))); + }; + int batch_count_A = get_batch_count(A->shape, batch_offset_A); int batch_count_B = get_batch_count(B->shape, batch_offset_B); int batch_count_C = get_batch_count(C->shape, C->ndim - 2); @@ -227,21 +226,26 @@ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, co int64_t batch_stride_B = K * N; int64_t batch_stride_C = M * N; - // CublasLT does not seem to support batched GEMM with one of matrices having + // cuBLASLt does not seem to support batched GEMM with one of the matrices having // one batch (with batch_stride 0). ICHECK_EQ(batch_count_A, batch_count_B); - set_batch(Adesc, batch_count_A, batch_stride_A); - set_batch(Bdesc, batch_count_B, batch_stride_B); - set_batch(Cdesc, batch_count_C, batch_stride_C); + set_batch(A_desc, batch_count_A, batch_stride_A); + set_batch(B_desc, batch_count_B, batch_stride_B); + set_batch(C_desc, batch_count_C, batch_stride_C); } auto A_data = static_cast(A->data) + A->byte_offset; auto B_data = static_cast(B->data) + B->byte_offset; auto C_data = static_cast(C->data) + C->byte_offset; - CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, operationDesc, alpha, B_data, Adesc, A_data, Bdesc, beta, - C_data, Cdesc, C_data, Cdesc, nullptr, nullptr, 0, nullptr)); + CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, A_data, B_desc, beta, + C_data, C_desc, C_data, C_desc, nullptr, nullptr, 0, nullptr)); + + cublasLtMatmulDescDestroy(op_desc); + cublasLtMatrixLayoutDestroy(A_desc); + cublasLtMatrixLayoutDestroy(B_desc); + cublasLtMatrixLayoutDestroy(C_desc); } inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) { diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 306d8161a59e..8afccb27302d 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -50,6 +50,7 @@ class CublasJSONRuntime : public JSONRuntimeBase { void Init(const Array& consts) override {} void Run() override { + // TODO(masahi): Reuse the same handle across different subgraphs cublasLtHandle_t handle; cublasLtCreate(&handle); @@ -57,74 +58,39 @@ class CublasJSONRuntime : public JSONRuntimeBase { const auto& node = nodes_[i]; if (node.GetOpType() == "kernel") { auto op_name = node.GetOpName(); - if (op_name == "cublas.matmul") { - uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 0); - auto b_ptr = GetInput(node, 1); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, nullptr, out_ptr, false, false, - CUBLASLT_EPILOGUE_DEFAULT); - } else if (op_name == "cublas.matmul_transposed") { - uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - // TODO: fix - auto a_ptr = GetInput(node, 1); - auto b_ptr = GetInput(node, 0); - 
tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, nullptr, out_ptr, false, true, - CUBLASLT_EPILOGUE_DEFAULT); - } else if (op_name == "cublas.matmul_bias") { - uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 0); - auto b_ptr = GetInput(node, 1); - auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, - CUBLASLT_EPILOGUE_BIAS); - } else if (op_name == "cublas.matmul_bias_relu") { - uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 0); - auto b_ptr = GetInput(node, 1); - auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, - CUBLASLT_EPILOGUE_RELU_BIAS); - } else if (op_name == "cublas.matmul_bias_gelu") { - uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 0); - auto b_ptr = GetInput(node, 1); - auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, false, - CUBLASLT_EPILOGUE_GELU_BIAS); - } else if (op_name == "cublas.matmul_transposed_bias") { - uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 1); - auto b_ptr = GetInput(node, 0); - auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, - CUBLASLT_EPILOGUE_BIAS); - } else if (op_name == "cublas.matmul_transposed_bias_relu") { - uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 1); - auto b_ptr = GetInput(node, 0); - auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, - CUBLASLT_EPILOGUE_RELU_BIAS); - } else if (op_name == "cublas.matmul_transposed_bias_gelu") { - uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - auto a_ptr = GetInput(node, 1); - auto b_ptr = GetInput(node, 0); - auto bias_ptr = GetInput(node, 2); - tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, false, true, - CUBLASLT_EPILOGUE_GELU_BIAS); - } else { - LOG(FATAL) << op_name; + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = data_entry_[output_eid]; + bool transa = false; + bool transb = false; + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + + if (op_name.find("transposed") != std::string::npos) { + transb = true; + } + + if (op_name.find("relu") != std::string::npos) { + epilogue = CUBLASLT_EPILOGUE_RELU_BIAS; + } else if (op_name.find("gelu") != std::string::npos) { + epilogue = CUBLASLT_EPILOGUE_GELU_BIAS; + } else if (op_name.find("bias") != std::string::npos) { + epilogue = CUBLASLT_EPILOGUE_BIAS; } + + auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) { + const DLTensor* bias = nullptr; + if (has_bias) { + bias = GetInput(node, 2); + } + return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias); + }; + + auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT); + + tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, transa, transb, + epilogue); } } + cublasLtDestroy(handle); } private: diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index ae730b67a945..ac03b1210366 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ 
b/src/runtime/contrib/cublas/cublas_utils.h @@ -105,6 +105,7 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) { LOG(FATAL) << "Unsupported cuda type"; } +/*! \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt. */ void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa, bool transb, cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT); From 021f6ae5002a6c4814cacf2a3ad681d53898ebc3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 14 Mar 2023 12:22:58 +0900 Subject: [PATCH 11/15] access arguments properly --- src/relax/backend/contrib/cublas/codegen.cc | 34 +++++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 74508ca002b8..3d61ec6f1368 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -22,8 +22,11 @@ * \brief Implementation of the CUBLAS JSON serializer. */ #include +#include + #include +#include "../../pattern_registry.h" #include "../codegen_json/codegen_json.h" #include "../utils.h" @@ -36,12 +39,32 @@ using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; using JSONSerializer = backend::contrib::JSONSerializer; using backend::contrib::NodeEntries; - Map ExtractArgIdx(String pattern_name, Function f) { Map arg_idx; - arg_idx.Set("lhs", IntImm(DataType::Int(64), 1)); - arg_idx.Set("rhs", IntImm(DataType::Int(64), 0)); - arg_idx.Set("bias", IntImm(DataType::Int(64), 2)); + auto pattern = backend::GetPattern(pattern_name); + ICHECK(pattern) << "Unsupported op_type " << pattern_name; + + auto bindings = AnalyzeVar2Value(f); + auto inner_body = Downcast(f->body)->body; + auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern, inner_body, bindings); + ICHECK(matched_expr); + + auto find_index = [](const Array& params, Var v) { + for (size_t i = 0; i < params.size(); ++i) { + if (params[i] == v) { + return i; + } + } + LOG(FATAL) << "Variable not found " << v; + return size_t(0); + }; + + for (const auto& [name, pat] : pattern.value()->arg_patterns) { + auto arg_var = matched_expr.value()[pat]; + auto idx = find_index(f->params, Downcast(arg_var)); + arg_idx.Set(name, IntImm(DataType::Int(64), idx)); + } + return arg_idx; } @@ -83,8 +106,7 @@ class CublasJSONSerializer : public JSONSerializer { "kernel", /* op_type_ */ inputs, 1 /* num_outputs_ */); - const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul");; - + const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul"); SetCallNodeAttribute(node, root_call); return AddNode(node, GetRef(call_node)); } From 57db54f85ad1c606f3bb10f2afd855fc5363e241 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 14 Mar 2023 12:51:46 +0900 Subject: [PATCH 12/15] expose ExtractArgIdx to python and use it in cutlass byoc --- python/tvm/contrib/cutlass/build.py | 7 ++- src/relax/backend/contrib/cublas/codegen.cc | 33 +---------- src/relax/backend/contrib/utils.cc | 64 +++++++++++++++++++++ src/relax/backend/contrib/utils.h | 13 +++++ 4 files changed, 84 insertions(+), 33 deletions(-) create mode 100644 src/relax/backend/contrib/utils.cc diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 93d1331ac443..48ee933d1a82 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -575,7 +575,12 @@ def 
_extract_arg_idx(pattern_name, f): continue arg_idx[name] = func_args.index(arg_expr) - return arg_idx +# return arg_idx + + # TODO fix + extract_func = tvm.get_global_func("relax.contrib.extract_arg_idx") + arg_indices = extract_func(pattern_name, f) + return {k: int(v) for k, v in arg_indices.items()} def is_shape_valid_for_cutlass_matmul( diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 3d61ec6f1368..e573d9a12385 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -22,11 +22,9 @@ * \brief Implementation of the CUBLAS JSON serializer. */ #include -#include #include -#include "../../pattern_registry.h" #include "../codegen_json/codegen_json.h" #include "../utils.h" @@ -39,35 +37,6 @@ using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; using JSONSerializer = backend::contrib::JSONSerializer; using backend::contrib::NodeEntries; -Map ExtractArgIdx(String pattern_name, Function f) { - Map arg_idx; - auto pattern = backend::GetPattern(pattern_name); - ICHECK(pattern) << "Unsupported op_type " << pattern_name; - - auto bindings = AnalyzeVar2Value(f); - auto inner_body = Downcast(f->body)->body; - auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern, inner_body, bindings); - ICHECK(matched_expr); - - auto find_index = [](const Array& params, Var v) { - for (size_t i = 0; i < params.size(); ++i) { - if (params[i] == v) { - return i; - } - } - LOG(FATAL) << "Variable not found " << v; - return size_t(0); - }; - - for (const auto& [name, pat] : pattern.value()->arg_patterns) { - auto arg_var = matched_expr.value()[pat]; - auto idx = find_index(f->params, Downcast(arg_var)); - arg_idx.Set(name, IntImm(DataType::Int(64), idx)); - } - - return arg_idx; -} - class CublasJSONSerializer : public JSONSerializer { public: CublasJSONSerializer(Map constant_names, Map bindings) @@ -95,7 +64,7 @@ class CublasJSONSerializer : public JSONSerializer { ICHECK(inputs_tmp.size() <= 3); NodeEntries inputs(inputs_tmp.size()); - auto arg_idx = ExtractArgIdx(composite_name, fn); + auto arg_idx = backend::ExtractArgIdx(composite_name, fn); inputs[0] = inputs_tmp[arg_idx["lhs"]->value]; inputs[1] = inputs_tmp[arg_idx["rhs"]->value]; if (inputs_tmp.size() == 3) { diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc new file mode 100644 index 000000000000..e0a93e93dac6 --- /dev/null +++ b/src/relax/backend/contrib/utils.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "utils.h" + +#include +#include +#include + +#include "../pattern_registry.h" + +namespace tvm { +namespace relax { +namespace backend { + +Map ExtractArgIdx(String pattern_name, Function f) { + Map arg_idx; + auto pattern = backend::GetPattern(pattern_name); + ICHECK(pattern) << "Unsupported op_type " << pattern_name; + + auto bindings = AnalyzeVar2Value(f); + auto inner_body = Downcast(f->body)->body; + auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern, inner_body, bindings); + ICHECK(matched_expr); + + auto find_index = [](const Array& params, Var v) { + for (size_t i = 0; i < params.size(); ++i) { + if (params[i] == v) { + return i; + } + } + LOG(FATAL) << "Variable not found " << v; + return size_t(0); + }; + + for (const auto& [name, pat] : pattern.value()->arg_patterns) { + auto arg_var = matched_expr.value()[pat]; + auto idx = find_index(f->params, Downcast(arg_var)); + arg_idx.Set(name, IntImm(DataType::Int(64), idx)); + } + + return arg_idx; +} + +TVM_REGISTER_GLOBAL("relax.contrib.extract_arg_idx").set_body_typed(ExtractArgIdx); + +} // namespace backend +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h index 4190ad66b6df..ee1240aaed2e 100644 --- a/src/relax/backend/contrib/utils.h +++ b/src/relax/backend/contrib/utils.h @@ -120,6 +120,19 @@ inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) { return nullptr; } +/*! + * \brief Extract indices of the argument patterns in the function parameter list. + * Each composite function pattern can register a mapping between variable names and the + * corresponding patterns. This function tells at which index a given parameter + * in the function pattern, identified by its name, appears in the partitioned function parameter + * list. + * \param pattern_name The name of the composite function pattern. + * \param f The function partitioned according to the function pattern. + * \return A mapping between variable pattern names and their positions in the partitioned + * function parameter list. + */ +Map ExtractArgIdx(String pattern_name, Function f); + } // namespace backend } // namespace relax } // namespace tvm From 5238694a6601e97ade8a11cc04593329ae723fa0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 14 Mar 2023 13:08:52 +0900 Subject: [PATCH 13/15] put matmul ir into common testing file --- python/tvm/relax/testing/__init__.py | 1 + python/tvm/relax/testing/matmul.py | 65 ++++++++++++++++++++++ tests/python/relax/test_codegen_cublas.py | 37 +----------- tests/python/relax/test_codegen_cutlass.py | 45 +-------------- 4 files changed, 68 insertions(+), 80 deletions(-) create mode 100644 python/tvm/relax/testing/matmul.py diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py index a6e3a9425147..4256ebc3be89 100644 --- a/python/tvm/relax/testing/__init__.py +++ b/python/tvm/relax/testing/__init__.py @@ -20,3 +20,4 @@ from .nn import * from .relay_translator import * from .ast_printer import dump_ast +from .matmul import * diff --git a/python/tvm/relax/testing/matmul.py b/python/tvm/relax/testing/matmul.py new file mode 100644 index 000000000000..ca95c2f104d4 --- /dev/null +++ b/python/tvm/relax/testing/matmul.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utilities to construct matmul workloads.""" +import tvm +from tvm.script import relax as R +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import relax as relax_builder + + +def get_relax_matmul_module( + x_shape, + y_shape, + dtype, + transposed_y=False, + with_bias=False, + activation=None, + residual_bin_op=None, + residual_activation=None, +): + if transposed_y: + n = y_shape[-2] + else: + n = y_shape[-1] + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + x = R.arg("x", R.Tensor(x_shape, dtype)) + y = R.arg("y", R.Tensor(y_shape, dtype)) + if with_bias: + bias = R.arg("bias", R.Tensor((n,), dtype)) + + with R.dataflow() as frame: + if transposed_y: + axes = list(range(len(y_shape) - 2)) + [-1, -2] + y = R.emit(R.permute_dims(y, axes=axes)) + result = R.emit(R.matmul(x, y, out_dtype=dtype)) + if with_bias: + result = R.emit(result + bias) + if activation is not None: + result = R.emit(activation(result)) + if residual_bin_op is not None: + result = R.emit(residual_bin_op(result, x)) + if residual_activation is not None: + result = R.emit(residual_activation(result)) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index 67484242a778..023054256efd 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -22,6 +22,7 @@ import tvm.topi.testing from tvm import relax from tvm.relax.backend.contrib.cublas import partition_for_cublas +from tvm.relax.testing import get_relax_matmul_module from tvm.script import relax as R @@ -59,42 +60,6 @@ def get_result_with_relax_cublas_offload(mod, *args): return build_and_run(mod, args, "cuda") -def get_relax_matmul_module( - x_shape, y_shape, dtype, transposed_y=False, with_bias=False, activation=None -): - if transposed_y: - n = y_shape[-2] - else: - n = y_shape[-1] - - from tvm.script.ir_builder import IRBuilder - from tvm.script.ir_builder import relax as relax_builder - - with IRBuilder() as builder: - with relax_builder.function(): - R.func_name("main") - x = R.arg("x", R.Tensor(x_shape, dtype)) - y = R.arg("y", R.Tensor(y_shape, dtype)) - if with_bias: - bias = R.arg("bias", R.Tensor((n,), dtype)) - - with R.dataflow() as frame: - if transposed_y: - axes = list(range(len(y_shape) - 2)) + [-1, -2] - y = R.emit(R.permute_dims(y, axes=axes)) - result = R.emit(R.matmul(x, y, out_dtype=dtype)) - if with_bias: - result = R.emit(result + bias) - if activation is not None: - result = R.emit(activation(result)) - R.output(result) - - R.func_ret_value(frame.output_vars[0]) - - func = builder.get() - return tvm.IRModule({"main": func}) - - def _to_concrete_shape(symbolic_shape, var_table): result = [] for dim in symbolic_shape: diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py index 1c814294c94e..b9ba4f4dc9af 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -24,6 +24,7 @@ from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul from tvm.contrib.pickle_memoize import memoize from tvm.relax.backend.contrib.cutlass import partition_for_cutlass +from tvm.relax.testing import get_relax_matmul_module from tvm.script import relax as R from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import relax as relax_builder @@ -164,50 +165,6 @@ def get_relax_conv2d_module( return tvm.IRModule({"main": func}) -def get_relax_matmul_module( - x_shape, - y_shape, - dtype, - transposed_y=False, - with_bias=False, - activation=None, - residual_bin_op=None, - residual_activation=None, -): - if transposed_y: - n = y_shape[-2] - else: - n = y_shape[-1] - - with IRBuilder() as builder: - with relax_builder.function(): - R.func_name("main") - x = R.arg("x", R.Tensor(x_shape, dtype)) - y = R.arg("y", R.Tensor(y_shape, dtype)) - if with_bias: - bias = R.arg("bias", R.Tensor((n,), dtype)) - - with R.dataflow() as frame: - if transposed_y: - axes = list(range(len(y_shape) - 2)) + [-1, -2] - y = R.emit(R.permute_dims(y, axes=axes)) - result = R.emit(R.matmul(x, y, out_dtype=dtype)) - if with_bias: - result = R.emit(result + bias) - if activation is not None: - result = R.emit(activation(result)) - if residual_bin_op is not None: - result = R.emit(residual_bin_op(result, x)) - if residual_activation is not None: - result = R.emit(residual_activation(result)) - R.output(result) - - R.func_ret_value(frame.output_vars[0]) - - func = builder.get() - return tvm.IRModule({"main": func}) - - def _to_concrete_shape(symbolic_shape, var_table): result = [] for dim in symbolic_shape: From 4019e05c9f34b16b37090f9b4ffece9bb4b7b009 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 22 Mar 2023 18:11:31 +0900 Subject: [PATCH 14/15] updated for the latest rev --- python/tvm/contrib/cutlass/build.py | 18 -------------- python/tvm/relax/backend/contrib/cublas.py | 29 +++++----------------- src/relax/backend/contrib/utils.cc | 18 ++++++++------ 3 files changed, 17 insertions(+), 48 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 48ee933d1a82..43494991a04c 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -560,24 +560,6 @@ def _extract_relax_function_signature(f): def _extract_arg_idx(pattern_name, f): - pattern_entry = relax.backend.get_pattern(pattern_name) - if pattern_entry is None: - raise ValueError(f"Unsupported op_type {pattern_name}") - var2val = relax.analysis.get_var2val(f) - matched_expr = pattern_entry.pattern.extract_matched_expr(f.body.body, var2val) - - func_args = list(f.params) - - arg_idx = {} - for name, annotation_pattern in pattern_entry.annotation_patterns.items(): - arg_expr = matched_expr[annotation_pattern] - if arg_expr not in func_args: - continue - arg_idx[name] = func_args.index(arg_expr) - -# return arg_idx - - # TODO fix extract_func = tvm.get_global_func("relax.contrib.extract_arg_idx") arg_indices = extract_func(pattern_name, f) return {k: int(v) for k, v in arg_indices.items()} diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py index d7e32f241dd2..b8c4952d8da1 100644 --- a/python/tvm/relax/backend/contrib/cublas.py +++ b/python/tvm/relax/backend/contrib/cublas.py @@ -19,14 +19,12 @@ 
import operator from functools import reduce -from typing import Mapping - import tvm -from tvm.relax import Call, Expr, transform -from tvm.relax.dpl import CallPattern, DFPattern +from tvm.relax import transform from ..pattern_registry import get_patterns_with_prefix, register_patterns from ..patterns import make_matmul_pattern +from tvm.relax.transform import PatternCheckContext def _is_supported_dtype(lhs_dtype, rhs_dtype): @@ -36,23 +34,9 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype): ) -def _check_matmul( - match_result: Mapping[DFPattern, Expr], - _: Expr, -) -> bool: - matmul_call: Call = None - for pattern, expr in match_result.items(): - if ( - isinstance(expr, Call) - and isinstance(pattern, CallPattern) - and isinstance(expr.op, tvm.ir.Op) - and expr.op.name == "relax.matmul" - ): - matmul_call = expr - if matmul_call is None: - raise ValueError("Cannot find call to matmul from match_result.") - - lhs, rhs, *_ = matmul_call.args +def _check_matmul(context: PatternCheckContext) -> bool: + lhs = context.annotated_expr["lhs"] + rhs = context.annotated_expr["rhs"] lhs_dtype = lhs.struct_info.dtype rhs_dtype = rhs.struct_info.dtype @@ -166,6 +150,5 @@ def partition_for_cublas(mod): offloaded to the cuBLAS backend. """ - cublas_pattern_entries = get_patterns_with_prefix("cublas") - patterns = [(e.name, e.pattern, e.check) for e in cublas_pattern_entries] + patterns = get_patterns_with_prefix("cublas") return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod) diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc index e0a93e93dac6..565b7769f0ad 100644 --- a/src/relax/backend/contrib/utils.cc +++ b/src/relax/backend/contrib/utils.cc @@ -22,6 +22,8 @@ #include #include +#include + #include "../pattern_registry.h" namespace tvm { @@ -38,20 +40,22 @@ Map ExtractArgIdx(String pattern_name, Function f) { auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern, inner_body, bindings); ICHECK(matched_expr); - auto find_index = [](const Array& params, Var v) { + auto find_index = [](const Array& params, Var v) -> std::optional { for (size_t i = 0; i < params.size(); ++i) { if (params[i] == v) { return i; } } - LOG(FATAL) << "Variable not found " << v; - return size_t(0); + return std::nullopt; }; - for (const auto& [name, pat] : pattern.value()->arg_patterns) { - auto arg_var = matched_expr.value()[pat]; - auto idx = find_index(f->params, Downcast(arg_var)); - arg_idx.Set(name, IntImm(DataType::Int(64), idx)); + for (const auto& [name, pat] : pattern.value()->annotation_patterns) { + auto exp = matched_expr.value()[pat]; + if (auto arg_var = exp.as()) { + if (auto idx = find_index(f->params, GetRef(arg_var))) { + arg_idx.Set(name, IntImm(DataType::Int(64), *idx)); + } + } } return arg_idx; From 54ca782f554d102af5ca3cf644e3cfd127c886f2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 3 Apr 2023 13:49:13 +0900 Subject: [PATCH 15/15] pylint --- python/tvm/relax/backend/contrib/cublas.py | 2 +- python/tvm/relax/backend/contrib/cutlass.py | 5 ++--- python/tvm/relax/testing/matmul.py | 1 + 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py index b8c4952d8da1..627c9369935b 100644 --- a/python/tvm/relax/backend/contrib/cublas.py +++ b/python/tvm/relax/backend/contrib/cublas.py @@ -21,10 +21,10 @@ import tvm from tvm.relax import transform +from tvm.relax.transform import PatternCheckContext from 
..pattern_registry import get_patterns_with_prefix, register_patterns from ..patterns import make_matmul_pattern -from tvm.relax.transform import PatternCheckContext def _is_supported_dtype(lhs_dtype, rhs_dtype): diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 8b243e52eae2..856cd4d7871f 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -17,11 +17,10 @@ """Pattern table for CUTLASS backend""" -from typing import Mapping, Optional, Sequence, Tuple +from typing import Mapping, Sequence -import tvm from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul -from tvm.relax import DataflowVar, ShapeExpr, Var, transform +from tvm.relax import DataflowVar, Var, transform from tvm.relax.transform import PatternCheckContext from ..pattern_registry import get_patterns_with_prefix, register_patterns diff --git a/python/tvm/relax/testing/matmul.py b/python/tvm/relax/testing/matmul.py index ca95c2f104d4..bac6fc6c9ae8 100644 --- a/python/tvm/relax/testing/matmul.py +++ b/python/tvm/relax/testing/matmul.py @@ -31,6 +31,7 @@ def get_relax_matmul_module( residual_bin_op=None, residual_activation=None, ): + """Create a matmul op followed by epilogue operations.""" if transposed_y: n = y_shape[-2] else:
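
With the whole series applied, the shared helper and the cuBLAS partitioning entry point compose roughly as follows (a usage sketch, assuming a CUDA-enabled TVM build with cuBLAS support; the shapes and the relu epilogue are arbitrary examples):

from tvm.relax.backend.contrib.cublas import partition_for_cublas
from tvm.relax.testing import get_relax_matmul_module
from tvm.script import relax as R

# Build a matmul + bias + relu workload and hand it to the BYOC flow.
mod = get_relax_matmul_module(
    (32, 8), (8, 16), "float16", with_bias=True, activation=R.nn.relu
)
mod = partition_for_cublas(mod)
# After partitioning, the fused chain should show up as a composite
# function annotated for the "cublas" codegen; printing makes that visible.
print(mod)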