From ccea726774fecb98700ab89d37fdef5c186c9d54 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 25 Oct 2023 03:58:29 +0900 Subject: [PATCH] Fix cuBLAS BYOC compatibility with Disco with ThreadedSession --- .../contrib/cublas/cublas_json_runtime.cc | 68 ++++++++++++++----- 1 file changed, 51 insertions(+), 17 deletions(-) diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 9617559d7eec..c6916d4f86fa 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -49,21 +49,69 @@ class CublasJSONRuntime : public JSONRuntimeBase { void Init(const Array& consts) override {} + PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since CublasJSONRuntime + // can be used by multiple GPUs running on different threads, we avoid using that function + // and directly call cuBLAS on the inputs from TVMArgs. + if (this->symbol_name_ == name) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK(this->initialized_) << "The module has not been initialized"; + this->Run(args); + }); + } else { + return JSONRuntimeBase::GetFunction(name, sptr_to_self); + } + } + const char* type_key() const override { return "cublas_json"; } // May be overridden - void Run() override { + void Run(TVMArgs args) { auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(); auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); ICHECK(func != nullptr); cudaStream_t stream = static_cast((*func)().operator void*()); + std::vector dl_tensors(NumEntries()); + + for (size_t i = 0; i < static_cast(args.size()); i++) { + auto eid = i < input_var_eid_.size() ? 
input_var_eid_[i] + : EntryID(outputs_[i - input_var_eid_.size()]); + ICHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle) + << "Expect NDArray or DLTensor as inputs"; + + const DLTensor* arg; + if (args[i].IsObjectRef()) { + NDArray arr = args[i]; + arg = arr.operator->(); + } else { + arg = args[i].operator DLTensor*(); + } + + dl_tensors[eid] = arg; + } + + auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { + ICHECK_LT(idx, node.GetInputs().size()); + auto eid = EntryID(node.GetInputs()[idx]); + ICHECK(eid < dl_tensors.size()); + return dl_tensors[eid]; + }; + + auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) { + const DLTensor* bias = nullptr; + if (has_bias) { + bias = get_input(node, 2); + } + return std::make_tuple(get_input(node, 0), get_input(node, 1), bias); + }; + for (size_t i = 0; i < nodes_.size(); ++i) { const auto& node = nodes_[i]; if (node.GetOpType() == "kernel") { auto op_name = node.GetOpName(); uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; + auto out_ptr = dl_tensors[output_eid]; bool transa = false; bool transb = false; cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; @@ -80,14 +128,6 @@ class CublasJSONRuntime : public JSONRuntimeBase { epilogue = CUBLASLT_EPILOGUE_BIAS; } - auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) { - const DLTensor* bias = nullptr; - if (has_bias) { - bias = GetInput(node, 2); - } - return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias); - }; - auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT); tvm::contrib::CallCublasLt(entry_ptr->handle, stream, a_ptr, b_ptr, bias_ptr, out_ptr, @@ -96,13 +136,7 @@ class CublasJSONRuntime : public JSONRuntimeBase { } } - private: - const DLTensor* GetInput(const JSONGraphNode& node, const int idx) { - ICHECK_LT(idx, node.GetInputs().size()); - auto eid = EntryID(node.GetInputs()[idx]); - 
ICHECK(eid < data_entry_.size()); - return data_entry_[eid]; - } + void Run() override { LOG(FATAL) << "Unreachable"; } }; runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json,