Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ struct BufferDescriptor;
class OpenCLWorkspace : public DeviceAPI {
public:
// type key
std::string type_key;
std::string type_key{"opencl"};
// available platforms
std::vector<cl_platform_id> platform_ids;
// map platform to its context
Expand Down Expand Up @@ -253,7 +253,7 @@ class OpenCLWorkspace : public DeviceAPI {
// Initialize the device.
void Init(const std::string& type_key, const std::string& device_type,
const std::string& platform_name = "");
virtual void Init() { Init("opencl", "gpu"); }
virtual void Init() { Init(this->type_key, "gpu"); }
// Check whether the context is OpenCL or not.
virtual bool IsOpenCLDevice(Device dev) { return dev.device_type == kDLOpenCL; }
// get the queue of the device
Expand Down Expand Up @@ -465,6 +465,8 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase {
: OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {}

PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;
// Return true if OpenCL program for the requested function and device was created
bool IsProgramCreated(const std::string& func_name, int device_id);
void SaveToFile(const String& file_name, const String& format) final;
void SaveToBinary(dmlc::Stream* stream) final;
void SetPreCompiledPrograms(const std::string& bytes);
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ OpenCLWorkspace* OpenCLWorkspace::Global() {
}

cl_device_id OpenCLWorkspace::GetCLDeviceID(int device_id) {
this->Init();
ICHECK_LT(device_id, devices.size()) << "Invalid device id " << device_id << ". " << GetError();
return devices[device_id];
}
Expand Down Expand Up @@ -210,6 +211,7 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)

void* OpenCLWorkspace::CreateHostPtrIfEnabled(cl::BufferDescriptor* desc, Device dev, size_t size) {
#if defined(OPENCL_ENABLE_HOST_PTR)
this->Init();
cl_int err_code;
desc->host_ptr = reinterpret_cast<cl_uchar*>(
clEnqueueMapBuffer(this->GetQueue(dev), desc->buffer, CL_TRUE, CL_MAP_WRITE, 0,
Expand Down Expand Up @@ -300,6 +302,7 @@ void OpenCLWorkspace::FreeTextureWorkspace(Device dev, void* ptr) {
}

void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
this->Init();
size_t nbytes = GetDataSize(*from);
ICHECK_EQ(nbytes, GetDataSize(*to));
ICHECK(IsContiguous(*from) && IsContiguous(*to))
Expand Down Expand Up @@ -379,6 +382,7 @@ void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHand
}

void OpenCLWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
this->Init();
ICHECK(stream == nullptr);
OPENCL_CALL(clFinish(this->GetQueue(dev)));
}
Expand Down
22 changes: 15 additions & 7 deletions src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ String OpenCLModuleNode::GetSource(const String& format) {

void OpenCLModuleNode::Init() {
workspace_ = GetGlobalWorkspace();
workspace_->Init();
// initialize the kernel id, need to lock global table.
std::lock_guard<std::mutex> lock(workspace_->mu);
for (const auto& kv : fmap_) {
Expand All @@ -208,10 +207,17 @@ void OpenCLModuleNode::Init() {
<< "delimiter was found.";
ICHECK_EQ(fmap_.size(), parsed_kernels_.size())
<< "The number of parsed kernel sources does not match the number of kernel functions";
}

bool OpenCLModuleNode::IsProgramCreated(const std::string& func_name, int device_id) {
auto size = programs_[func_name].size();
if (size > 0 && programs_[func_name][device_id] != nullptr) return true;
auto dev_size = GetGlobalWorkspace()->devices.size();
ICHECK(device_id < static_cast<int>(dev_size))
<< "Device id " << device_id << " is bigger than number of available devices";
// zero initialize cl_program pointers for each device kernel
for (auto& kv : parsed_kernels_) {
programs_.insert({kv.first, std::vector<cl_program>(workspace_->devices.size(), nullptr)});
}
if (size == 0) programs_[func_name].resize(dev_size, nullptr);
return false;
}

cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
Expand All @@ -220,7 +226,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
int device_id = t->device.device_id;
auto did = w->GetCLDeviceID(device_id);
auto platform = w->device_to_platform[did];
if (programs_[func_name][device_id] == nullptr) {
if (!IsProgramCreated(func_name, device_id)) {
// create program
if (fmt_ == "cl") {
const char* s = parsed_kernels_[func_name].c_str();
Expand Down Expand Up @@ -268,6 +274,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
}

void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) {
workspace_->Init();
std::string data = bytes;
dmlc::MemoryStringStream reader(&data);
dmlc::Stream* strm = &reader;
Expand All @@ -280,7 +287,7 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) {
std::vector<unsigned char> bin_vector;
strm->Read(&name);
strm->Read(&bin_vector);
if (programs_[name][device_id] == nullptr) {
if (!IsProgramCreated(name, device_id)) {
cl_int err = 0;
cl_int binaryStatus;
size_t binarySize = bin_vector.size();
Expand Down Expand Up @@ -310,6 +317,7 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) {
}

std::string OpenCLModuleNode::GetPreCompiledPrograms() {
workspace_->Init();
std::string data;
dmlc::MemoryStringStream writer(&data);
dmlc::Stream* strm = &writer;
Expand All @@ -319,7 +327,7 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() {
cl::OpenCLThreadEntry* t = workspace_->GetThreadEntry();
int device_id = t->device.device_id;
t->kernel_table.resize(workspace_->num_registered_kernels);
if (programs_[std::string(name)][device_id] == nullptr) {
if (!IsProgramCreated(name, device_id)) {
InstallKernel(workspace_, t, name, kid_map_[name]);
}
size_t size;
Expand Down
1 change: 1 addition & 0 deletions src/runtime/opencl/opencl_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace runtime {
* \param data The module data.
* \param fmt The format of the data, can be "clbin", "cl"
* \param fmap The map function information map of each function.
* \param source Generated OpenCL kernels.
*/
Module OpenCLModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source);
Expand Down