Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,15 @@ class OpenCLWorkspace : public DeviceAPI {
ICHECK(IsOpenCLDevice(dev));
this->Init();
ICHECK(dev.device_id >= 0 && static_cast<size_t>(dev.device_id) < queues.size())
<< "Invalid OpenCL device_id=" << dev.device_id;
<< "Invalid OpenCL device_id=" << dev.device_id << ". " << GetError();
return queues[dev.device_id];
}
// get the event queue of the context
std::vector<cl_event>& GetEventQueue(Device dev) {
ICHECK(IsOpenCLDevice(dev));
this->Init();
ICHECK(dev.device_id >= 0 && static_cast<size_t>(dev.device_id) < queues.size())
<< "Invalid OpenCL device_id=" << dev.device_id;
<< "Invalid OpenCL device_id=" << dev.device_id << ". " << GetError();
return events[dev.device_id];
}
// is current clCommandQueue in profiling mode
Expand Down Expand Up @@ -310,6 +310,13 @@ class OpenCLWorkspace : public DeviceAPI {
static OpenCLWorkspace* Global();

void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final;

private:
std::string GetError() {
if (this->devices.size() == 0) return noDevicesErrorMsg;
return "";
}
std::string noDevicesErrorMsg = "";
};

/*! \brief Thread local workspace */
Expand Down
70 changes: 52 additions & 18 deletions src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/registry.h>

#include <sstream>

#include "opencl_common.h"

namespace tvm {
Expand All @@ -33,6 +35,7 @@ namespace cl {

std::string GetPlatformInfo(cl_platform_id pid, cl_platform_info param_name);
std::string GetDeviceInfo(cl_device_id pid, cl_device_info param_name);
std::string GetOpenCLVersion(cl_device_id pid);

struct ImageInfo {
size_t origin[3] = {};
Expand Down Expand Up @@ -111,7 +114,7 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
*rv = static_cast<int>(index < devices.size());
return;
}
ICHECK_LT(index, devices.size()) << "Invalid device id " << index;
ICHECK_LT(index, devices.size()) << "Invalid device id " << index << ". " << GetError();
switch (kind) {
case kExist:
break;
Expand Down Expand Up @@ -139,17 +142,9 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
*rv = static_cast<int64_t>(value);
break;
}
case kComputeVersion: {
// String returned is "OpenCL $MAJOR.$MINOR $VENDOR_INFO". To
// match other implementations, we want to return "$MAJOR.$MINOR"
std::string ret = GetDeviceInfo(devices[index], CL_DEVICE_VERSION);

const size_t version_start = 7; // Length of initial "OpenCL " prefix to skip
const size_t version_end = ret.find(' ', version_start);
*rv = ret.substr(version_start, version_end - version_start);
case kComputeVersion:
*rv = GetOpenCLVersion(devices[index]);
break;
}
return;
case kDeviceName:
*rv = GetDeviceInfo(devices[index], CL_DEVICE_NAME);
break;
Expand Down Expand Up @@ -200,7 +195,7 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t size, size_t alignment,
DLDataType type_hint) {
this->Init();
ICHECK(context != nullptr) << "No OpenCL device";
ICHECK(context != nullptr) << "No OpenCL device. " << GetError();
cl_int err_code;
cl::BufferDescriptor* desc = new cl::BufferDescriptor;
// CL_INVALID_BUFFER_SIZE if size is 0.
Expand Down Expand Up @@ -245,7 +240,7 @@ void OpenCLWorkspace::FreeDataSpace(Device dev, void* ptr) {
cl_mem OpenCLWorkspace::AllocTexture(Device dev, size_t width, size_t height,
DLDataType type_hint) {
this->Init();
ICHECK(context != nullptr) << "No OpenCL device";
ICHECK(context != nullptr) << "No OpenCL device. " << GetError();
cl_int err_code;
cl_channel_type cl_type = DTypeToOpenCLChannelType(type_hint);
cl_image_format format = {CL_RGBA, cl_type};
Expand Down Expand Up @@ -373,12 +368,23 @@ std::string GetPlatformInfo(cl_platform_id pid, cl_platform_info param_name) {
std::string GetDeviceInfo(cl_device_id pid, cl_device_info param_name) {
size_t ret_size;
OPENCL_CALL(clGetDeviceInfo(pid, param_name, 0, nullptr, &ret_size));
std::string ret;
ret.resize(ret_size);
OPENCL_CALL(clGetDeviceInfo(pid, param_name, ret_size, &ret[0], nullptr));
char* info = new char[ret_size];
OPENCL_CALL(clGetDeviceInfo(pid, param_name, ret_size, info, nullptr));
std::string ret = info;
delete[] info;
return ret;
}

std::string GetOpenCLVersion(cl_device_id pid) {
// String returned is "OpenCL $MAJOR.$MINOR $VENDOR_INFO". To
// match other implementations, we want to return "$MAJOR.$MINOR"
std::string ret = GetDeviceInfo(pid, CL_DEVICE_VERSION);

const size_t version_start = 7; // Length of initial "OpenCL " prefix to skip
const size_t version_end = ret.find(' ', version_start);
return ret.substr(version_start, version_end - version_start);
}

std::vector<cl_platform_id> GetPlatformIDs() {
cl_uint ret_size;
cl_int code = clGetPlatformIDs(0, nullptr, &ret_size);
Expand Down Expand Up @@ -432,16 +438,44 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
LOG(WARNING) << "Using CPU OpenCL device";
devices_matched = cl::GetDeviceIDs(platform_id, "cpu");
}
if (devices_matched.size() > 0) {
std::vector<cl_device_id> supported_devices = {};
auto get_version_str = [](int version) {
std::ostringstream out;
out.precision(1);
out << std::fixed << version / 100.f;
return out.str();
};
for (auto& device : devices_matched) {
std::string ver = GetOpenCLVersion(device);
int opencl_version = std::stod(ver) * 100;
if (opencl_version >= CL_TARGET_OPENCL_VERSION) {
supported_devices.push_back(device);
} else {
std::string dev_msg = GetDeviceInfo(device, CL_DEVICE_NAME) +
" has OpenCL version == " + get_version_str(opencl_version);
LOG(WARNING) << "TVM supports devices with OpenCL version >= "
<< get_version_str(CL_TARGET_OPENCL_VERSION) << ", device " << dev_msg
<< ". This device will be ignored.";

if (noDevicesErrorMsg.empty()) {
noDevicesErrorMsg =
"Probably this error happen because TVM supports devices with OpenCL version >= " +
get_version_str(CL_TARGET_OPENCL_VERSION) + ". We found the following devices:\n";
}
noDevicesErrorMsg += "\t" + dev_msg + "\n";
}
}
if (supported_devices.size() > 0) {
this->platform_id = platform_id;
this->platform_name = cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME);
this->device_type = device_type;
this->devices = devices_matched;
this->devices = supported_devices;
break;
}
}
if (this->platform_id == nullptr) {
LOG(WARNING) << "No OpenCL device";
initialized_ = true;
return;
}
cl_int err_code;
Expand Down