Skip to content
1 change: 1 addition & 0 deletions cmake/modules/OpenCL.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ endif(USE_AOCL)

if(USE_OPENCL)
tvm_file_glob(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc)
list(APPEND COMPILER_SRCS src/target/spirv/spirv_utils.cc)

if(${USE_OPENCL} MATCHES ${IS_TRUE_PATTERN})
message(STATUS "Enabled runtime search for OpenCL library location")
Expand Down
1 change: 1 addition & 0 deletions cmake/modules/Vulkan.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ if(USE_VULKAN)
list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY})
add_definitions(-DTVM_ENABLE_SPIRV=1)
endif(USE_VULKAN)
60 changes: 38 additions & 22 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,18 +411,16 @@ struct BufferDescriptor {
// To make the call thread-safe, we create a thread-local kernel table
// and lazily install new kernels into the kernel table when the kernel is called.
// The kernels are recycled when the module get destructed.
class OpenCLModuleNode : public ModuleNode {
class OpenCLModuleNodeBase : public ModuleNode {
public:
// Kernel table reference entry.
struct KTRefEntry {
size_t kernel_id;
size_t version;
};
explicit OpenCLModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: data_(data), fmt_(fmt), fmap_(fmap), source_(source) {}
explicit OpenCLModuleNodeBase(std::unordered_map<std::string, FunctionInfo> fmap) : fmap_(fmap) {}
// destructor
~OpenCLModuleNode();
~OpenCLModuleNodeBase();

/*!
* \brief Get the global workspace
Expand All @@ -436,38 +434,56 @@ class OpenCLModuleNode : public ModuleNode {
return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable;
}

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
void SaveToFile(const std::string& file_name, const std::string& format) final;
void SaveToBinary(dmlc::Stream* stream) final;
std::string GetSource(const std::string& format) final;
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override;

// Initialize the programs
void Init();
virtual void Init() = 0;
// install a new kernel to thread local entry
cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e);
void SetPreCompiledPrograms(const std::string& bytes);
std::string GetPreCompiledPrograms();
virtual cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e) = 0;

private:
protected:
// The workspace, need to keep reference to use it in destructor.
// In case of static destruction order problem.
cl::OpenCLWorkspace* workspace_;
// the binary data
std::string data_;
// The format
std::string fmt_;
// function information table.
std::unordered_map<std::string, FunctionInfo> fmap_;
// Module local mutex
std::mutex build_lock_;
// The OpenCL source.
std::string source_;
// Mapping from primitive name to cl program for each device.
std::unordered_map<std::string, std::vector<cl_program>> programs_;
// kernel id cache
std::unordered_map<std::string, KTRefEntry> kid_map_;
// kernels build so far.
// kernels built so far.
std::vector<cl_kernel> kernels_;
};

/*!
 * \brief OpenCL module node that carries module data in a given format together with the
 *        OpenCL source, built on top of OpenCLModuleNodeBase.
 */
class OpenCLModuleNode : public OpenCLModuleNodeBase {
public:
/*!
 * \brief Construct the node.
 * \param data The module data.
 * \param fmt The format of the data.
 * \param fmap The function information table; forwarded to the base class.
 * \param source The OpenCL source.
 */
explicit OpenCLModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {}

// Look up a function by name. Serves "opencl.GetPreCompiledPrograms" and
// "opencl.SetPreCompiledPrograms" itself and defers every other name to the base class.
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
// Save the module to a file.
void SaveToFile(const std::string& file_name, const std::string& format) final;
// Serialize the module into a binary stream.
void SaveToBinary(dmlc::Stream* stream) final;
// Exposed through GetFunction as "opencl.SetPreCompiledPrograms".
void SetPreCompiledPrograms(const std::string& bytes);
// Exposed through GetFunction as "opencl.GetPreCompiledPrograms".
std::string GetPreCompiledPrograms();
// Get the source of the module.
std::string GetSource(const std::string& format) final;

// Initialize the programs
void Init() override;
// install a new kernel to thread local entry
cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e) override;

private:
// the binary data
std::string data_;
// The format
std::string fmt_;
// The OpenCL source.
std::string source_;
// parsed kernel data
// NOTE(review): presumably keyed by kernel name; confirm against the code that fills it.
std::unordered_map<std::string, std::string> parsed_kernels_;
};
Expand Down
36 changes: 21 additions & 15 deletions src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace runtime {
class OpenCLWrappedFunc {
public:
// initialize the OpenCL function.
void Init(OpenCLModuleNode* m, ObjectPtr<Object> sptr, OpenCLModuleNode::KTRefEntry entry,
void Init(OpenCLModuleNodeBase* m, ObjectPtr<Object> sptr, OpenCLModuleNode::KTRefEntry entry,
std::string func_name, std::vector<size_t> arg_size,
const std::vector<std::string>& launch_param_tags) {
w_ = m->GetGlobalWorkspace();
Expand Down Expand Up @@ -95,7 +95,7 @@ class OpenCLWrappedFunc {
// global workspace.
cl::OpenCLWorkspace* w_;
// The module
OpenCLModuleNode* m_;
OpenCLModuleNodeBase* m_;
// resource handle
ObjectPtr<Object> sptr_;
// global kernel id in the kernel table.
Expand All @@ -108,7 +108,7 @@ class OpenCLWrappedFunc {
LaunchParamConfig launch_param_config_;
};

OpenCLModuleNode::~OpenCLModuleNode() {
OpenCLModuleNodeBase::~OpenCLModuleNodeBase() {
{
// free the kernel ids in global table.
std::lock_guard<std::mutex> lock(workspace_->mu);
Expand All @@ -130,22 +130,13 @@ OpenCLModuleNode::~OpenCLModuleNode() {
}
}

cl::OpenCLWorkspace* OpenCLModuleNode::GetGlobalWorkspace() {
cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() {
return cl::OpenCLWorkspace::Global();
}

PackedFunc OpenCLModuleNode::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
PackedFunc OpenCLModuleNodeBase::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
ICHECK_EQ(sptr_to_self.get(), this);
if (name == "opencl.GetPreCompiledPrograms") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetPreCompiledPrograms();
});
} else if (name == "opencl.SetPreCompiledPrograms") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->SetPreCompiledPrograms(args[0]);
});
}
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
Expand Down Expand Up @@ -344,6 +335,21 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() {
return data;
}

// Resolve a function by name.
// The two "opencl.*PreCompiledPrograms" entry points are answered here; any other
// name is handed to OpenCLModuleNodeBase::GetFunction.
PackedFunc OpenCLModuleNode::GetFunction(const std::string& name,
                                         const ObjectPtr<Object>& sptr_to_self) {
  ICHECK_EQ(sptr_to_self.get(), this);

  // sptr_to_self is captured by value next to `this` in both closures below,
  // presumably so that the module stays referenced while the closure exists.
  if (name == "opencl.GetPreCompiledPrograms") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = this->GetPreCompiledPrograms();
    });
  }

  if (name == "opencl.SetPreCompiledPrograms") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      this->SetPreCompiledPrograms(args[0]);
    });
  }

  // Not one of the module-specific helpers: use the generic kernel lookup.
  return OpenCLModuleNodeBase::GetFunction(name, sptr_to_self);
}

Module OpenCLModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
Expand Down
14 changes: 13 additions & 1 deletion src/runtime/opencl/opencl_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,30 @@
#include <vector>

#include "../meta_data.h"
#include "../spirv/spirv_shader.h"

namespace tvm {
namespace runtime {
/*!
* \brief create a opencl module for GPU devices from data.
* \brief Create a opencl module for GPU devices from data.
*
* \param data The module data.
* \param fmt The format of the data, can be "clbin", "cl"
* \param fmap The map function information map of each function.
*/
Module OpenCLModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source);

/*!
 * \brief Create an OpenCL module from SPIRV.
 *
 * \param shaders The map from function names to SPIRV binaries.
 * \param spirv_text The concatenated text representation of SPIRV modules.
 * \param fmap The map from each function name to its function information.
 * \return The created module.
 */
Module OpenCLModuleCreate(const std::unordered_map<std::string, spirv::SPIRVShader>& shaders,
const std::string& spirv_text,
std::unordered_map<std::string, FunctionInfo> fmap);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_
137 changes: 137 additions & 0 deletions src/runtime/opencl/opencl_module_spirv.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>

#include <string>
#include <unordered_map>
#include <vector>

#include "../source_utils.h"
#include "../spirv/spirv_shader.h"
#include "opencl_common.h"
#include "opencl_module.h"

namespace tvm {
namespace runtime {

/*!
 * \brief OpenCL module node whose kernels are provided as SPIRV shaders keyed by function name.
 */
class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase {
public:
/*!
 * \brief Construct the node. The shader table and the text are copied into the node;
 *        fmap is forwarded to the base class.
 * \param shaders The map from function names to SPIRV shaders.
 * \param spirv_text The text returned by GetSource.
 * \param fmap The function information table.
 */
explicit OpenCLSPIRVModuleNode(const std::unordered_map<std::string, SPIRVShader>& shaders,
const std::string& spirv_text,
std::unordered_map<std::string, FunctionInfo> fmap)
: OpenCLModuleNodeBase(fmap), shaders_(shaders), spirv_text_(spirv_text) {}

// Not implemented for SPIRV modules; the definition raises a fatal error.
void SaveToFile(const std::string& file_name, const std::string& format) final;
// Writes the function table followed by the shader table to the stream.
void SaveToBinary(dmlc::Stream* stream) final;
// Returns the stored SPIRV text; the requested format is ignored.
std::string GetSource(const std::string&) final { return spirv_text_; }

// Set up the workspace, kernel ids and the per-device program cache.
void Init() override;
// Build the program for func_name on demand and install a new kernel into the
// thread-local kernel table.
cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e) override;

private:
// SPIRV shaders, keyed by function name.
std::unordered_map<std::string, SPIRVShader> shaders_;
// Text form of the SPIRV modules, returned by GetSource.
std::string spirv_text_;
};

// Saving a SPIRV-backed OpenCL module to a file is not supported yet:
// both arguments are ignored and the call always raises a fatal error.
void OpenCLSPIRVModuleNode::SaveToFile(const std::string& file_name, const std::string& format) {
// TODO(masahi): How should SPIRV binaries be saved to a file?
LOG(FATAL) << "Not implemented.";
}

// Serialize the module into the stream: the function information table first,
// then the shader table. A reader must consume the two in the same order.
// Note that spirv_text_ is not written.
void OpenCLSPIRVModuleNode::SaveToBinary(dmlc::Stream* stream) {
stream->Write(fmap_);
stream->Write(shaders_);
}

void OpenCLSPIRVModuleNode::Init() {
workspace_ = GetGlobalWorkspace();
workspace_->Init();
// initialize the kernel id, need to lock global table.
std::lock_guard<std::mutex> lock(workspace_->mu);
for (const auto& kv : fmap_) {
const std::string& key = kv.first;
KTRefEntry e;
if (workspace_->free_kernel_ids.size() != 0) {
e.kernel_id = workspace_->free_kernel_ids.back();
workspace_->free_kernel_ids.pop_back();
} else {
e.kernel_id = workspace_->num_registered_kernels++;
}
e.version = workspace_->timestamp++;
kid_map_[key] = e;
}

// zero initialize cl_program pointers for each device kernel
for (auto& kv : shaders_) {
programs_.insert({kv.first, std::vector<cl_program>(workspace_->devices.size(), nullptr)});
}
}

/*!
 * \brief Create a kernel for func_name on the calling thread's device and record it in the
 *        thread-local kernel table.
 *
 * The program for (func_name, device) is created from the stored SPIRV words and built the
 * first time it is requested; later calls reuse the cached program. The whole function runs
 * under build_lock_.
 *
 * \param w The OpenCL workspace providing devices, platforms and contexts.
 * \param t The thread-local entry whose device is targeted and whose kernel table is updated.
 * \param func_name The name of the kernel; must be one of the shaders given to this module.
 * \param e The kernel table reference (slot id and version) to fill in.
 * \return The newly created kernel.
 */
cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
                                               const std::string& func_name, const KTRefEntry& e) {
  std::lock_guard<std::mutex> lock(build_lock_);
  const int device_id = t->device.device_id;

  // Look the cache entry up with find(): operator[] would silently insert an empty
  // vector for an unknown name and the subsequent indexing would be out of bounds.
  auto prog_it = programs_.find(func_name);
  ICHECK(prog_it != programs_.end()) << "Cannot find SPIRV program entry for kernel " << func_name;
  ICHECK_LT(static_cast<size_t>(device_id), prog_it->second.size())
      << "Invalid OpenCL device id " << device_id << " for kernel " << func_name;
  cl_program& program = prog_it->second[device_id];

  if (program == nullptr) {
    auto shader_it = shaders_.find(func_name);
    ICHECK(shader_it != shaders_.end()) << "Cannot find SPIRV shader for kernel " << func_name;
    // The shader is stored as 32-bit words; OpenCL wants a byte pointer and a byte length.
    const unsigned char* binary =
        reinterpret_cast<const unsigned char*>(shader_it->second.data.data());
    size_t binary_size = shader_it->second.data.size() * sizeof(uint32_t);
    cl_int err;
    cl_device_id dev = w->devices[device_id];
    auto platform = w->device_to_platform[dev];
    program = clCreateProgramWithBinary(w->contexts[platform], 1, &dev, &binary_size, &binary,
                                        nullptr, &err);
    OPENCL_CHECK_ERROR(err);

    // build program
    err = clBuildProgram(program, 1, &dev, nullptr, nullptr, nullptr);

    if (err != CL_SUCCESS) {
      size_t log_size = 0;
      std::string log;
      clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &log_size);
      log.resize(log_size);
      clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, log_size, &log[0], nullptr);
      // Do not keep a program that failed to build in the cache: a later call would
      // skip the build step and try to create a kernel from it.
      clReleaseProgram(program);
      program = nullptr;
      LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log;
    }
  }
  // build kernel
  cl_int err;
  cl_kernel kernel = clCreateKernel(program, func_name.c_str(), &err);
  OPENCL_CHECK_ERROR(err);
  t->kernel_table[e.kernel_id].kernel = kernel;
  t->kernel_table[e.kernel_id].version = e.version;
  // Remember the kernel in kernels_, the module's list of kernels built so far.
  kernels_.push_back(kernel);
  return kernel;
}

// Factory for SPIRV-backed OpenCL modules: construct the node from the shader table,
// the SPIRV text and the function table, initialize it, and wrap it in a Module.
Module OpenCLModuleCreate(const std::unordered_map<std::string, SPIRVShader>& shaders,
                          const std::string& spirv_text,
                          std::unordered_map<std::string, FunctionInfo> fmap) {
  auto node = make_object<OpenCLSPIRVModuleNode>(shaders, spirv_text, fmap);
  // Init() must run before the module is handed out: it assigns kernel ids and
  // creates the per-device program slots.
  node->Init();
  return Module(node);
}

} // namespace runtime
} // namespace tvm
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
* under the License.
*/

#ifndef TVM_RUNTIME_VULKAN_VULKAN_SHADER_H_
#define TVM_RUNTIME_VULKAN_VULKAN_SHADER_H_
#ifndef TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_
#define TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
Expand All @@ -29,9 +29,9 @@

namespace tvm {
namespace runtime {
namespace vulkan {
namespace spirv {

struct VulkanShader {
struct SPIRVShader {
/*! \brief header flag */
uint32_t flag{0};
/*! \brief Data segment */
Expand All @@ -48,13 +48,13 @@ struct VulkanShader {
}
};

} // namespace vulkan
} // namespace spirv

using vulkan::VulkanShader;
using spirv::SPIRVShader;
} // namespace runtime
} // namespace tvm

namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::vulkan::VulkanShader, true);
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::spirv::SPIRVShader, true);
} // namespace dmlc
#endif // TVM_RUNTIME_VULKAN_VULKAN_SHADER_H_
#endif // TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_
Loading