diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index a3031413578f..8c1607c4e56f 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -220,7 +220,7 @@ struct BufferDescriptor; class OpenCLWorkspace : public DeviceAPI { public: // type key - std::string type_key; + std::string type_key{"opencl"}; // available platforms std::vector platform_ids; // map platform to its context @@ -253,7 +253,7 @@ class OpenCLWorkspace : public DeviceAPI { // Initialize the device. void Init(const std::string& type_key, const std::string& device_type, const std::string& platform_name = ""); - virtual void Init() { Init("opencl", "gpu"); } + virtual void Init() { Init(this->type_key, "gpu"); } // Check whether the context is OpenCL or not. virtual bool IsOpenCLDevice(Device dev) { return dev.device_type == kDLOpenCL; } // get the queue of the device @@ -465,6 +465,8 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase { : OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {} PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + // Return true if OpenCL program for the requested function and device was created + bool IsProgramCreated(const std::string& func_name, int device_id); void SaveToFile(const String& file_name, const String& format) final; void SaveToBinary(dmlc::Stream* stream) final; void SetPreCompiledPrograms(const std::string& bytes); diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 35e77eb6d1f6..fb9adc27573d 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -111,6 +111,7 @@ OpenCLWorkspace* OpenCLWorkspace::Global() { } cl_device_id OpenCLWorkspace::GetCLDeviceID(int device_id) { + this->Init(); ICHECK_LT(device_id, devices.size()) << "Invalid device id " << device_id << ". " << GetError(); return devices[device_id]; } @@ -210,6 +211,7 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) void* OpenCLWorkspace::CreateHostPtrIfEnabled(cl::BufferDescriptor* desc, Device dev, size_t size) { #if defined(OPENCL_ENABLE_HOST_PTR) + this->Init(); cl_int err_code; desc->host_ptr = reinterpret_cast( clEnqueueMapBuffer(this->GetQueue(dev), desc->buffer, CL_TRUE, CL_MAP_WRITE, 0, @@ -300,6 +302,7 @@ void OpenCLWorkspace::FreeTextureWorkspace(Device dev, void* ptr) { } void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { + this->Init(); size_t nbytes = GetDataSize(*from); ICHECK_EQ(nbytes, GetDataSize(*to)); ICHECK(IsContiguous(*from) && IsContiguous(*to)) @@ -379,6 +382,7 @@ void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHand } void OpenCLWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { + this->Init(); ICHECK(stream == nullptr); OPENCL_CALL(clFinish(this->GetQueue(dev))); } diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 6829d46d4339..567b7ad88a9e 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -185,7 +185,6 @@ String OpenCLModuleNode::GetSource(const String& format) { void OpenCLModuleNode::Init() { workspace_ = GetGlobalWorkspace(); - workspace_->Init(); // initialize the kernel id, need to lock global table. std::lock_guard lock(workspace_->mu); for (const auto& kv : fmap_) { @@ -208,10 +207,17 @@ void OpenCLModuleNode::Init() { << "delimiter was found."; ICHECK_EQ(fmap_.size(), parsed_kernels_.size()) << "The number of parsed kernel sources does not match the number of kernel functions"; +} + +bool OpenCLModuleNode::IsProgramCreated(const std::string& func_name, int device_id) { + auto size = programs_[func_name].size(); + if (size > 0 && programs_[func_name][device_id] != nullptr) return true; + auto dev_size = GetGlobalWorkspace()->devices.size(); + ICHECK(device_id < static_cast(dev_size)) + << "Device id " << device_id << " is bigger than number of available devices"; // zero initialize cl_program pointers for each device kernel - for (auto& kv : parsed_kernels_) { - programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); - } + if (size == 0) programs_[func_name].resize(dev_size, nullptr); + return false; } cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, @@ -220,7 +226,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre int device_id = t->device.device_id; auto did = w->GetCLDeviceID(device_id); auto platform = w->device_to_platform[did]; - if (programs_[func_name][device_id] == nullptr) { + if (!IsProgramCreated(func_name, device_id)) { // create program if (fmt_ == "cl") { const char* s = parsed_kernels_[func_name].c_str(); @@ -268,6 +274,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre } void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) { + workspace_->Init(); std::string data = bytes; dmlc::MemoryStringStream reader(&data); dmlc::Stream* strm = &reader; @@ -280,7 +287,7 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) { std::vector bin_vector; strm->Read(&name); strm->Read(&bin_vector); - if (programs_[name][device_id] == nullptr) { + if (!IsProgramCreated(name, device_id)) { cl_int err = 0; cl_int binaryStatus; size_t binarySize = bin_vector.size(); @@ -310,6 +317,7 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) { } std::string OpenCLModuleNode::GetPreCompiledPrograms() { + workspace_->Init(); std::string data; dmlc::MemoryStringStream writer(&data); dmlc::Stream* strm = &writer; @@ -319,7 +327,7 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { cl::OpenCLThreadEntry* t = workspace_->GetThreadEntry(); int device_id = t->device.device_id; t->kernel_table.resize(workspace_->num_registered_kernels); - if (programs_[std::string(name)][device_id] == nullptr) { + if (!IsProgramCreated(name, device_id)) { InstallKernel(workspace_, t, name, kid_map_[name]); } size_t size; diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 834f53510ecc..22fc119e0318 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -42,6 +42,7 @@ namespace runtime { * \param data The module data. * \param fmt The format of the data, can be "clbin", "cl" * \param fmap The map function information map of each function. + * \param source Generated OpenCL kernels. */ Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source);