Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 51 additions & 17 deletions src/runtime/contrib/cublas/cublas_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,69 @@ class CublasJSONRuntime : public JSONRuntimeBase {

// Initialization hook for this runtime. The body is intentionally empty: the
// `consts` argument is accepted to satisfy the overridden signature but is not used.
void Init(const Array<NDArray>& consts) override {}

/*!
 * \brief Look up a packed function exposed by this module.
 *
 * A request for the module's own symbol is answered with a closure that checks
 * the module is initialized and then forwards the call arguments directly to
 * Run(TVMArgs). Any other name is delegated to JSONRuntimeBase::GetFunction.
 *
 * \param name The name of the requested function.
 * \param sptr_to_self Object pointer to this module; the returned closure holds a copy of it.
 * \return The PackedFunc registered under the given name.
 */
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) override {
  const bool is_own_symbol = (this->symbol_name_ == name);
  if (!is_own_symbol) {
    // Not our entry point: let the base class resolve it.
    return JSONRuntimeBase::GetFunction(name, sptr_to_self);
  }

  // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since CublasJSONRuntime
  // can be used by multiple GPUs running on different threads, we avoid using that function
  // and directly call cuBLAS on the inputs from TVMArgs.
  auto invoke = [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
    ICHECK(this->initialized_) << "The module has not been initialized";
    this->Run(args);
  };
  return PackedFunc(invoke);
}

// Returns the type key string of this module, the literal "cublas_json".
const char* type_key() const override { return "cublas_json"; }  // May be overridden

void Run() override {
void Run(TVMArgs args) {
auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();

auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());

std::vector<const DLTensor*> dl_tensors(NumEntries());

for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
: EntryID(outputs_[i - input_var_eid_.size()]);
ICHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle)
<< "Expect NDArray or DLTensor as inputs";

const DLTensor* arg;
if (args[i].IsObjectRef<NDArray>()) {
NDArray arr = args[i];
arg = arr.operator->();
} else {
arg = args[i].operator DLTensor*();
}

dl_tensors[eid] = arg;
}

auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) {
ICHECK_LT(idx, node.GetInputs().size());
auto eid = EntryID(node.GetInputs()[idx]);
ICHECK(eid < dl_tensors.size());
return dl_tensors[eid];
};

auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) {
const DLTensor* bias = nullptr;
if (has_bias) {
bias = get_input(node, 2);
}
return std::make_tuple(get_input(node, 0), get_input(node, 1), bias);
};

for (size_t i = 0; i < nodes_.size(); ++i) {
const auto& node = nodes_[i];
if (node.GetOpType() == "kernel") {
auto op_name = node.GetOpName();
uint32_t output_eid = EntryID(outputs_[0]);
auto out_ptr = data_entry_[output_eid];
auto out_ptr = dl_tensors[output_eid];
bool transa = false;
bool transb = false;
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
Expand All @@ -80,14 +128,6 @@ class CublasJSONRuntime : public JSONRuntimeBase {
epilogue = CUBLASLT_EPILOGUE_BIAS;
}

auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) {
const DLTensor* bias = nullptr;
if (has_bias) {
bias = GetInput(node, 2);
}
return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias);
};

auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);

tvm::contrib::CallCublasLt(entry_ptr->handle, stream, a_ptr, b_ptr, bias_ptr, out_ptr,
Expand All @@ -96,13 +136,7 @@ class CublasJSONRuntime : public JSONRuntimeBase {
}
}

private:
// Returns the tensor stored in data_entry_ for the idx-th input of `node`.
// Both the input index and the resolved entry id are bounds-checked before use.
const DLTensor* GetInput(const JSONGraphNode& node, const int idx) {
// idx must name an existing input of the node.
ICHECK_LT(idx, node.GetInputs().size());
// Translate the node's input entry into an index into data_entry_
// (EntryID is defined outside this view).
auto eid = EntryID(node.GetInputs()[idx]);
// Guard the lookup below against an out-of-range entry id.
ICHECK(eid < data_entry_.size());
return data_entry_[eid];
}
// Argument-less Run override. It only reports a fatal error with the message
// "Unreachable"; GetFunction in this class dispatches the module's symbol to the
// Run overload that takes TVMArgs instead.
// NOTE(review): presumably no base-class path still invokes this overload - confirm.
void Run() override { LOG(FATAL) << "Unreachable"; }
};

runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json,
Expand Down