From 51eefd76ddee3b0bfd2832e6bd2dac02d5f323ad Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 31 Dec 2022 17:47:12 +0900 Subject: [PATCH 01/13] [OpenCL] Support SPIRV module ingestion --- src/runtime/opencl/opencl_common.h | 49 +++++-- src/runtime/opencl/opencl_module.cc | 121 ++++++++++++++++-- src/runtime/opencl/opencl_module.h | 5 + .../{vulkan_shader.h => spirv_shader.h} | 6 +- src/runtime/vulkan/vulkan_module.cc | 6 +- src/runtime/vulkan/vulkan_module.h | 4 +- src/runtime/vulkan/vulkan_wrapped_func.h | 6 +- src/target/source/codegen_opencl.cc | 8 ++ src/target/spirv/build_vulkan.cc | 39 ++++-- src/target/spirv/codegen_spirv.cc | 5 +- src/target/spirv/codegen_spirv.h | 9 +- src/target/spirv/spirv_support.cc | 5 +- web/emcc/webgpu_runtime.cc | 1 + 13 files changed, 209 insertions(+), 55 deletions(-) rename src/runtime/vulkan/{vulkan_shader.h => spirv_shader.h} (92%) diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index a8a4cf3dc65c..6d760b550582 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -411,18 +411,16 @@ struct BufferDescriptor { // To make the call thread-safe, we create a thread-local kernel table // and lazily install new kernels into the kernel table when the kernel is called. // The kernels are recycled when the module get destructed. -class OpenCLModuleNode : public ModuleNode { +class OpenCLModuleNodeBase : public ModuleNode { public: // Kernel table reference entry. struct KTRefEntry { size_t kernel_id; size_t version; }; - explicit OpenCLModuleNode(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) - : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {} + explicit OpenCLModuleNodeBase(std::unordered_map fmap) : fmap_(fmap) {} // destructor - ~OpenCLModuleNode(); + ~OpenCLModuleNodeBase(); /*! * \brief Get the global workspace @@ -448,26 +446,51 @@ class OpenCLModuleNode : public ModuleNode { void SetPreCompiledPrograms(const std::string& bytes); std::string GetPreCompiledPrograms(); - private: + // Initialize the programs + virtual void Init() = 0; + // install a new kernel to thread local entry + virtual cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) = 0; + + protected: // The workspace, need to keep reference to use it in destructor. // In case of static destruction order problem. cl::OpenCLWorkspace* workspace_; - // the binary data - std::string data_; - // The format - std::string fmt_; // function information table. std::unordered_map fmap_; // Module local mutex std::mutex build_lock_; - // The OpenCL source. - std::string source_; // Mapping from primitive name to cl program for each device. std::unordered_map> programs_; // kernel id cache std::unordered_map kid_map_; - // kernels build so far. + // kernels built so far. std::vector kernels_; +}; + +class OpenCLModuleNode : public OpenCLModuleNodeBase { + public: + explicit OpenCLModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) + : OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {} + + void SaveToFile(const std::string& file_name, const std::string& format) final; + void SaveToBinary(dmlc::Stream* stream) final; + + std::string GetSource(const std::string& format) final; + // Initialize the programs + void Init() override; + // install a new kernel to thread local entry + cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) override; + + private: + // the binary data + std::string data_; + // The format + std::string fmt_; + // The OpenCL source. + std::string source_; // parsed kernel data std::unordered_map parsed_kernels_; }; diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 7c084758a456..bee88d897830 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -30,6 +30,7 @@ #include #include "../source_utils.h" +#include "../vulkan/spirv_shader.h" #include "opencl_common.h" namespace tvm { @@ -38,7 +39,7 @@ namespace runtime { class OpenCLWrappedFunc { public: // initialize the OpenCL function. - void Init(OpenCLModuleNode* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, + void Init(OpenCLModuleNodeBase* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, std::string func_name, std::vector arg_size, const std::vector& launch_param_tags) { w_ = m->GetGlobalWorkspace(); @@ -95,7 +96,7 @@ class OpenCLWrappedFunc { // global workspace. cl::OpenCLWorkspace* w_; // The module - OpenCLModuleNode* m_; + OpenCLModuleNodeBase* m_; // resource handle ObjectPtr sptr_; // global kernel id in the kernel table. @@ -108,7 +109,30 @@ class OpenCLWrappedFunc { LaunchParamConfig launch_param_config_; }; -OpenCLModuleNode::~OpenCLModuleNode() { +class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { + public: + explicit OpenCLSPIRVModuleNode(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) + : OpenCLModuleNodeBase(fmap), shaders_(shaders), spirv_text_(spirv_text) {} + + void SaveToFile(const std::string& file_name, const std::string& format) final; + void SaveToBinary(dmlc::Stream* stream) final; + + std::string GetSource(const std::string&) final { return spirv_text_; } + + // Initialize the programs + void Init() override; + // install a new kernel to thread local entry + cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) override; + + private: + std::unordered_map shaders_; + std::string spirv_text_; +}; + +OpenCLModuleNodeBase::~OpenCLModuleNodeBase() { { // free the kernel ids in global table. std::lock_guard lock(workspace_->mu); @@ -130,12 +154,12 @@ OpenCLModuleNode::~OpenCLModuleNode() { } } -cl::OpenCLWorkspace* OpenCLModuleNode::GetGlobalWorkspace() { +cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() { return cl::OpenCLWorkspace::Global(); } -PackedFunc OpenCLModuleNode::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc OpenCLModuleNodeBase::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); if (name == "opencl.GetPreCompiledPrograms") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -218,8 +242,8 @@ void OpenCLModuleNode::Init() { ICHECK_EQ(fmap_.size(), parsed_kernels_.size()) << "The number of parsed kernel sources does not match the number of kernel functions"; // zero initialize cl_program pointers for each device kernel - for (auto& kv : parsed_kernels_) { - programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); + for (const auto& [func_name, _] : parsed_kernels_) { + programs_.insert({func_name, std::vector(workspace_->devices.size(), nullptr)}); } } @@ -238,7 +262,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre programs_[func_name][device_id] = clCreateProgramWithSource(w->contexts[platform], 1, &s, &len, &err); OPENCL_CHECK_ERROR(err); - } else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") { + } else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx" || fmt_ == "spirv") { const unsigned char* s = (const unsigned char*)data_.c_str(); size_t len = data_.length(); cl_int err; @@ -351,6 +375,85 @@ Module OpenCLModuleCreate(std::string data, std::string fmt, return Module(n); } +void OpenCLSPIRVModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { + LOG(FATAL) << "Not implemented"; +} + +void OpenCLSPIRVModuleNode::SaveToBinary(dmlc::Stream* stream) { + stream->Write(fmap_); + stream->Write(shaders_); +} + +void OpenCLSPIRVModuleNode::Init() { + workspace_ = GetGlobalWorkspace(); + workspace_->Init(); + // initialize the kernel id, need to lock global table. + std::lock_guard lock(workspace_->mu); + for (const auto& kv : fmap_) { + const std::string& key = kv.first; + KTRefEntry e; + if (workspace_->free_kernel_ids.size() != 0) { + e.kernel_id = workspace_->free_kernel_ids.back(); + workspace_->free_kernel_ids.pop_back(); + } else { + e.kernel_id = workspace_->num_registered_kernels++; + } + e.version = workspace_->timestamp++; + kid_map_[key] = e; + } + + // zero initialize cl_program pointers for each device kernel + for (const auto& [func_name, _] : shaders_) { + programs_.insert({func_name, std::vector(workspace_->devices.size(), nullptr)}); + } +} + +cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) { + std::lock_guard lock(build_lock_); + int device_id = t->device.device_id; + if (programs_[func_name][device_id] == nullptr) { + auto it = shaders_.find(func_name); + const unsigned char* s = (const unsigned char*)it->second.data.data(); + size_t len = it->second.data.size() * sizeof(uint32_t); + cl_int err; + cl_device_id dev = w->devices[device_id]; + programs_[func_name][device_id] = + clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, nullptr, &err); + OPENCL_CHECK_ERROR(err); + + // build program + err = clBuildProgram(programs_[func_name][device_id], 1, &dev, nullptr, nullptr, nullptr); + + if (err != CL_SUCCESS) { + size_t len; + std::string log; + clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, + &len); + log.resize(len); + clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, len, + &log[0], nullptr); + LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log; + } + } + // build kernel + cl_int err; + cl_kernel kernel = clCreateKernel(programs_[func_name][device_id], func_name.c_str(), &err); + OPENCL_CHECK_ERROR(err); + t->kernel_table[e.kernel_id].kernel = kernel; + t->kernel_table[e.kernel_id].version = e.version; + kernels_.push_back(kernel); + return kernel; +} + +Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) { + auto n = make_object(shaders, spirv_text, fmap); + n->Init(); + return Module(n); +} + // Load module from module. Module OpenCLModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 77f4b8010779..be1674116223 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -32,6 +32,7 @@ #include #include "../meta_data.h" +#include "../vulkan/spirv_shader.h" namespace tvm { namespace runtime { @@ -44,6 +45,10 @@ namespace runtime { */ Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source); + +Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ diff --git a/src/runtime/vulkan/vulkan_shader.h b/src/runtime/vulkan/spirv_shader.h similarity index 92% rename from src/runtime/vulkan/vulkan_shader.h rename to src/runtime/vulkan/spirv_shader.h index 513e3bccc36e..4a9f616d9bc1 100644 --- a/src/runtime/vulkan/vulkan_shader.h +++ b/src/runtime/vulkan/spirv_shader.h @@ -31,7 +31,7 @@ namespace tvm { namespace runtime { namespace vulkan { -struct VulkanShader { +struct SPIRVShader { /*! \brief header flag */ uint32_t flag{0}; /*! \brief Data segment */ @@ -50,11 +50,11 @@ struct VulkanShader { } // namespace vulkan -using vulkan::VulkanShader; +using vulkan::SPIRVShader; } // namespace runtime } // namespace tvm namespace dmlc { -DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::vulkan::VulkanShader, true); +DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::vulkan::SPIRVShader, true); } // namespace dmlc #endif // TVM_RUNTIME_VULKAN_VULKAN_SHADER_H_ diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index 89104d9d63d9..232cf1d58ec7 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -29,7 +29,7 @@ namespace tvm { namespace runtime { namespace vulkan { -Module VulkanModuleCreate(std::unordered_map smap, +Module VulkanModuleCreate(std::unordered_map smap, std::unordered_map fmap, std::string source) { auto n = make_object(smap, fmap, source); return Module(n); @@ -37,7 +37,7 @@ Module VulkanModuleCreate(std::unordered_map smap, Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; - std::unordered_map smap; + std::unordered_map smap; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); @@ -54,7 +54,7 @@ Module VulkanModuleLoadFile(const std::string& file_name, const std::string& for Module VulkanModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); - std::unordered_map smap; + std::unordered_map smap; std::unordered_map fmap; std::string fmt; diff --git a/src/runtime/vulkan/vulkan_module.h b/src/runtime/vulkan/vulkan_module.h index c75a077a361d..075a5d60bb83 100644 --- a/src/runtime/vulkan/vulkan_module.h +++ b/src/runtime/vulkan/vulkan_module.h @@ -24,12 +24,12 @@ #include #include "../meta_data.h" -#include "vulkan_shader.h" +#include "spirv_shader.h" namespace tvm { namespace runtime { namespace vulkan { -Module VulkanModuleCreate(std::unordered_map smap, +Module VulkanModuleCreate(std::unordered_map smap, std::unordered_map fmap, std::string source); } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index 187736e82a6d..31f61cdbb7d8 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -30,10 +30,10 @@ #include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" +#include "spirv_shader.h" #include "vulkan/vulkan_core.h" #include "vulkan_common.h" #include "vulkan_device.h" -#include "vulkan_shader.h" namespace tvm { namespace runtime { @@ -82,7 +82,7 @@ class VulkanWrappedFunc { class VulkanModuleNode final : public runtime::ModuleNode { public: - explicit VulkanModuleNode(std::unordered_map smap, + explicit VulkanModuleNode(std::unordered_map smap, std::unordered_map fmap, std::string source) : smap_(smap), fmap_(fmap), source_(source) {} ~VulkanModuleNode(); @@ -106,7 +106,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { private: // function information table. - std::unordered_map smap_; + std::unordered_map smap_; // function information table. std::unordered_map fmap_; // The format diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 525ee95f4117..a039164055f3 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -30,6 +30,7 @@ #include "../../runtime/texture.h" #include "../../runtime/thread_storage_scope.h" #include "../build_common.h" +#include "../spirv/codegen_spirv.h" namespace tvm { namespace codegen { @@ -585,6 +586,13 @@ void CodeGenOpenCL::SetTextureScope( } runtime::Module BuildOpenCL(IRModule mod, Target target) { + Optional device = target->GetAttr("device"); + + if (device && device.value() == "spirv") { + auto [smap, spirv_text] = TranslateToSPIRV(mod, target); + return runtime::OpenCLModuleCreate(smap, spirv_text, ExtractFuncInfo(mod)); + } + using tvm::runtime::Registry; bool output_ssa = false; diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index dc1d8f865baa..a0ab42390f97 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -29,8 +29,8 @@ #include #include +#include "../../runtime/vulkan/spirv_shader.h" #include "../../runtime/vulkan/vulkan_module.h" -#include "../../runtime/vulkan/vulkan_shader.h" #include "../../support/utils.h" #include "../build_common.h" #include "codegen_spirv.h" @@ -47,19 +47,24 @@ class SPIRVTools { target->GetAttr("max_spirv_version").value_or(0x10000).IntValue(); spv_target_env validation_version; - if (vulkan_version >= VK_API_VERSION_1_2) { - validation_version = SPV_ENV_VULKAN_1_2; - } else if (vulkan_version >= VK_API_VERSION_1_1 && spirv_version >= 0x10400) { - validation_version = SPV_ENV_VULKAN_1_1_SPIRV_1_4; - } else if (vulkan_version >= VK_API_VERSION_1_1) { - validation_version = SPV_ENV_VULKAN_1_1; + if (target->kind->name == "opencl") { + validation_version = SPV_ENV_OPENCL_2_2; } else { - validation_version = SPV_ENV_VULKAN_1_0; + if (vulkan_version >= VK_API_VERSION_1_2) { + validation_version = SPV_ENV_VULKAN_1_2; + } else if (vulkan_version >= VK_API_VERSION_1_1 && spirv_version >= 0x10400) { + validation_version = SPV_ENV_VULKAN_1_1_SPIRV_1_4; + } else if (vulkan_version >= VK_API_VERSION_1_1) { + validation_version = SPV_ENV_VULKAN_1_1; + } else { + validation_version = SPV_ENV_VULKAN_1_0; + } } - ctx_ = spvContextCreate(validation_version); } + ~SPIRVTools() { spvContextDestroy(ctx_); } + std::string BinaryToText(const std::vector& bin) { spv_text text = nullptr; spv_diagnostic diagnostic = nullptr; @@ -97,13 +102,14 @@ class SPIRVTools { spv_context ctx_; }; -runtime::Module BuildSPIRV(IRModule mod, Target target) { +std::pair, std::string> TranslateToSPIRV( + IRModule mod, Target target, bool webgpu_restriction) { using tvm::runtime::Registry; - using tvm::runtime::VulkanShader; + using tvm::runtime::SPIRVShader; std::ostringstream code_data; SPIRVTools spirv_tools(target); - std::unordered_map smap; + std::unordered_map smap; const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc"); @@ -124,7 +130,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target) { std::string f_name = global_symbol.value(); std::string entry = f_name; - VulkanShader shader = cg.BuildFunction(f, entry); + SPIRVShader shader = cg.BuildFunction(f, entry); if (auto path = std::getenv("TVM_VULKAN_DEBUG_SHADER_SAVEPATH")) { if (*path) { @@ -158,7 +164,12 @@ runtime::Module BuildSPIRV(IRModule mod, Target target) { smap[f_name] = std::move(shader); } - return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), code_data.str()); + return std::make_pair(smap, code_data.str()); +} + +runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) { + auto [smap, spirv_text] = TranslateToSPIRV(mod, target, webgpu_restriction); + return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), spirv_text); } TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) { diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index e3ef5acb8331..2a4233b44bcf 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -31,7 +31,6 @@ #include "../../runtime/pack_args.h" #include "../../runtime/vulkan/vulkan_common.h" -#include "../../runtime/vulkan/vulkan_shader.h" #include "../../tir/transforms/ir_utils.h" namespace tvm { @@ -39,7 +38,7 @@ namespace codegen { CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {} -runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { +runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; @@ -79,7 +78,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: spirv::Value func_ptr = builder_->NewFunction(); builder_->StartFunction(func_ptr); - runtime::VulkanShader shader; + runtime::SPIRVShader shader; if (pod_args.size() != 0) { std::vector value_types; diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 08b9db0ee539..a7de2faae82d 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -37,7 +37,7 @@ #include #include "../../runtime/thread_storage_scope.h" -#include "../../runtime/vulkan/vulkan_shader.h" +#include "../../runtime/vulkan/spirv_shader.h" #include "ir_builder.h" #include "spirv_support.h" @@ -66,7 +66,7 @@ class CodeGenSPIRV : public ExprFunctor, * \param name The name of the target function. * \return The final spirv module. */ - virtual runtime::VulkanShader BuildFunction(const PrimFunc& f, const std::string& name); + virtual runtime::SPIRVShader BuildFunction(const PrimFunc& f, const std::string& name); /*! * \brief Create Value for expression e * \param e The expression to be created value for. @@ -153,7 +153,7 @@ class CodeGenSPIRV : public ExprFunctor, * product of the number of lanes of the buffer element type and * the number of lanes of the index. */ - void CheckContentType(DataType type, int index_lanes = 1) { + void CheckContentType(DataType type, int index_lanes = 1) const { ICHECK(element_type_known) << "Cannot check element type of buffer " << name_hint << " no previous element type defined"; DataType expected_type = element_type.with_lanes(index_lanes * element_type.lanes()); @@ -220,6 +220,9 @@ class CodeGenSPIRV : public ExprFunctor, size_t shared_memory_bytes_used_{0}; }; +std::pair, std::string> TranslateToSPIRV( + IRModule mod, Target target, bool webgpu_restriction = false); + } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index 81b5cd8b8a6a..1b46e7f08339 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -32,8 +32,9 @@ namespace tvm { namespace codegen { SPIRVSupport::SPIRVSupport(tvm::Target target) { - ICHECK_EQ(target->GetTargetDeviceType(), kDLVulkan) - << "SPIRVSupport can only be checked for vulkan device type"; + auto device_type = target->GetTargetDeviceType(); + ICHECK(device_type == kDLVulkan || device_type == kDLOpenCL || device_type == kDLWebGPU) + << "Unsupported device type for SPIRV codegen:" << device_type; if (target->GetAttr("vulkan_api_version")) { vulkan_api_version = target->GetAttr("vulkan_api_version").value().IntValue(); diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 17efcc8c70a7..2bd012991b67 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -38,6 +38,7 @@ #include #include "../../src/runtime/meta_data.h" +#include "../../src/runtime/vulkan/spirv_shader.h" #include "../../src/runtime/workspace_pool.h" namespace tvm { From 3812b11b11be370bae24eb9ed57ae224d21a2529 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 7 May 2023 04:06:29 +0900 Subject: [PATCH 02/13] compile fixed --- src/runtime/opencl/opencl_common.h | 15 ++++----------- src/runtime/opencl/opencl_module.cc | 27 +++++++++++++++++---------- src/target/spirv/build_vulkan.cc | 6 +++--- src/target/spirv/codegen_spirv.h | 2 +- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 6d760b550582..58b4756ffea4 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -434,17 +434,7 @@ class OpenCLModuleNodeBase : public ModuleNode { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; } - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, const std::string& format) final; - void SaveToBinary(dmlc::Stream* stream) final; - std::string GetSource(const std::string& format) final; - // Initialize the programs - void Init(); - // install a new kernel to thread local entry - cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, - const std::string& func_name, const KTRefEntry& e); - void SetPreCompiledPrograms(const std::string& bytes); - std::string GetPreCompiledPrograms(); + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override; // Initialize the programs virtual void Init() = 0; @@ -476,6 +466,8 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase { void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; + void SetPreCompiledPrograms(const std::string& bytes); + std::string GetPreCompiledPrograms(); std::string GetSource(const std::string& format) final; // Initialize the programs @@ -483,6 +475,7 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase { // install a new kernel to thread local entry cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, const std::string& func_name, const KTRefEntry& e) override; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override; private: // the binary data diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index bee88d897830..d066aff0a5e5 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -161,15 +161,6 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() { PackedFunc OpenCLModuleNodeBase::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); - if (name == "opencl.GetPreCompiledPrograms") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetPreCompiledPrograms(); - }); - } else if (name == "opencl.SetPreCompiledPrograms") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->SetPreCompiledPrograms(args[0]); - }); - } ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); @@ -368,6 +359,21 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { return data; } +PackedFunc OpenCLModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + ICHECK_EQ(sptr_to_self.get(), this); + if (name == "opencl.GetPreCompiledPrograms") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetPreCompiledPrograms(); + }); + } else if (name == "opencl.SetPreCompiledPrograms") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + this->SetPreCompiledPrograms(args[0]); + }); + } + return OpenCLModuleNodeBase::GetFunction(name, sptr_to_self); +} + Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); @@ -418,8 +424,9 @@ cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenC size_t len = it->second.data.size() * sizeof(uint32_t); cl_int err; cl_device_id dev = w->devices[device_id]; + auto platform = w->device_to_platform[dev]; programs_[func_name][device_id] = - clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, nullptr, &err); + clCreateProgramWithBinary(w->contexts[platform], 1, &dev, &len, &s, nullptr, &err); OPENCL_CHECK_ERROR(err); // build program diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index a0ab42390f97..ca239a69c8bf 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -103,7 +103,7 @@ class SPIRVTools { }; std::pair, std::string> TranslateToSPIRV( - IRModule mod, Target target, bool webgpu_restriction) { + IRModule mod, Target target) { using tvm::runtime::Registry; using tvm::runtime::SPIRVShader; @@ -167,8 +167,8 @@ std::pair, std::string> Tr return std::make_pair(smap, code_data.str()); } -runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) { - auto [smap, spirv_text] = TranslateToSPIRV(mod, target, webgpu_restriction); +runtime::Module BuildSPIRV(IRModule mod, Target target) { + auto [smap, spirv_text] = TranslateToSPIRV(mod, target); return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), spirv_text); } diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index a7de2faae82d..5a5c21fd58c5 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -221,7 +221,7 @@ class CodeGenSPIRV : public ExprFunctor, }; std::pair, std::string> TranslateToSPIRV( - IRModule mod, Target target, bool webgpu_restriction = false); + IRModule mod, Target target); } // namespace codegen } // namespace tvm From f7a6e9b217546513a0684c96083cba642e85d521 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 7 May 2023 04:48:35 +0900 Subject: [PATCH 03/13] clean up --- src/runtime/opencl/opencl_common.h | 4 ++-- src/runtime/opencl/opencl_module.cc | 14 ++++++-------- src/runtime/opencl/opencl_module.h | 9 ++++++++- src/target/source/codegen_opencl.cc | 2 +- src/target/spirv/build_vulkan.cc | 4 ++-- src/target/spirv/codegen_spirv.h | 10 +++++++++- web/emcc/webgpu_runtime.cc | 1 - 7 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 58b4756ffea4..d25d2db0eb9f 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -464,18 +464,18 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase { std::unordered_map fmap, std::string source) : OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {} + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; void SetPreCompiledPrograms(const std::string& bytes); std::string GetPreCompiledPrograms(); - std::string GetSource(const std::string& format) final; + // Initialize the programs void Init() override; // install a new kernel to thread local entry cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, const std::string& func_name, const KTRefEntry& e) override; - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override; private: // the binary data diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index d066aff0a5e5..e39b4be17fb7 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -118,12 +118,9 @@ class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; - std::string GetSource(const std::string&) final { return spirv_text_; } - // Initialize the programs void Init() override; - // install a new kernel to thread local entry cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, const std::string& func_name, const KTRefEntry& e) override; @@ -233,8 +230,8 @@ void OpenCLModuleNode::Init() { ICHECK_EQ(fmap_.size(), parsed_kernels_.size()) << "The number of parsed kernel sources does not match the number of kernel functions"; // zero initialize cl_program pointers for each device kernel - for (const auto& [func_name, _] : parsed_kernels_) { - programs_.insert({func_name, std::vector(workspace_->devices.size(), nullptr)}); + for (auto& kv : parsed_kernels_) { + programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); } } @@ -382,7 +379,8 @@ Module OpenCLModuleCreate(std::string data, std::string fmt, } void OpenCLSPIRVModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { - LOG(FATAL) << "Not implemented"; + // TODO(masahi): How SPIRV binaries should be save to a file? + LOG(FATAL) << "Not implemented."; } void OpenCLSPIRVModuleNode::SaveToBinary(dmlc::Stream* stream) { @@ -409,8 +407,8 @@ void OpenCLSPIRVModuleNode::Init() { } // zero initialize cl_program pointers for each device kernel - for (const auto& [func_name, _] : shaders_) { - programs_.insert({func_name, std::vector(workspace_->devices.size(), nullptr)}); + for (auto& kv : shaders_) { + programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); } } diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index be1674116223..415704142e9e 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -37,7 +37,7 @@ namespace tvm { namespace runtime { /*! - * \brief create a opencl module for GPU devices from data. + * \brief Create a opencl module for GPU devices from data. * * \param data The module data. * \param fmt The format of the data, can be "clbin", "cl" @@ -46,6 +46,13 @@ namespace runtime { Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source); +/*! + * \brief Create a opencl module from SPIRV. + * + * \param shaders The map from function names to SPIRV binaries. + * \param spirv_text The concatenated text representation of SPIRV modules. + * \param fmap The map function information map of each function. + */ Module OpenCLModuleCreate(const std::unordered_map& shaders, const std::string& spirv_text, std::unordered_map fmap); diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index a039164055f3..5956aafd3498 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -589,7 +589,7 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) { Optional device = target->GetAttr("device"); if (device && device.value() == "spirv") { - auto [smap, spirv_text] = TranslateToSPIRV(mod, target); + auto [smap, spirv_text] = LowerToSPIRV(mod, target); return runtime::OpenCLModuleCreate(smap, spirv_text, ExtractFuncInfo(mod)); } diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index ca239a69c8bf..a2f0cd168c23 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -102,7 +102,7 @@ class SPIRVTools { spv_context ctx_; }; -std::pair, std::string> TranslateToSPIRV( +std::pair, std::string> LowerToSPIRV( IRModule mod, Target target) { using tvm::runtime::Registry; using tvm::runtime::SPIRVShader; @@ -168,7 +168,7 @@ std::pair, std::string> Tr } runtime::Module BuildSPIRV(IRModule mod, Target target) { - auto [smap, spirv_text] = TranslateToSPIRV(mod, target); + auto [smap, spirv_text] = LowerToSPIRV(mod, target); return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), spirv_text); } diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 5a5c21fd58c5..c7825e0ce9a9 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -220,7 +220,15 @@ class CodeGenSPIRV : public ExprFunctor, size_t shared_memory_bytes_used_{0}; }; -std::pair, std::string> TranslateToSPIRV( +/*! + * \brief Lower an IRModule to SPIRV modules. + * + * \param mod The IRModule to lower. + * \param target The target information. + * \return The map from function names to SPIRV binaries, and the concatenated text representation + * of the SPIRV modules. + */ +std::pair, std::string> LowerToSPIRV( IRModule mod, Target target); } // namespace codegen diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 2bd012991b67..17efcc8c70a7 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -38,7 +38,6 @@ #include #include "../../src/runtime/meta_data.h" -#include "../../src/runtime/vulkan/spirv_shader.h" #include "../../src/runtime/workspace_pool.h" namespace tvm { From 8e8f4f72a29ec70789974c40893764f53261b26d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 7 May 2023 09:27:39 +0900 Subject: [PATCH 04/13] fix build when vulkan is not enabled --- src/target/source/codegen_opencl.cc | 2 +- src/target/spirv/build_vulkan.cc | 1 + src/target/spirv/codegen_spirv.h | 11 ------- src/target/spirv/spirv_utils.h | 45 +++++++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 src/target/spirv/spirv_utils.h diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 5956aafd3498..d8a0eddcc1ab 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -30,7 +30,7 @@ #include "../../runtime/texture.h" #include "../../runtime/thread_storage_scope.h" #include "../build_common.h" -#include "../spirv/codegen_spirv.h" +#include "../spirv/spirv_utils.h" namespace tvm { namespace codegen { diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index a2f0cd168c23..e5561dad34cc 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -34,6 +34,7 @@ #include "../../support/utils.h" #include "../build_common.h" #include "codegen_spirv.h" +#include "spirv_utils.h" namespace tvm { namespace codegen { diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index c7825e0ce9a9..7564eb3ca608 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -220,17 +220,6 @@ class CodeGenSPIRV : public ExprFunctor, size_t shared_memory_bytes_used_{0}; }; -/*! - * \brief Lower an IRModule to SPIRV modules. - * - * \param mod The IRModule to lower. - * \param target The target information. - * \return The map from function names to SPIRV binaries, and the concatenated text representation - * of the SPIRV modules. - */ -std::pair, std::string> LowerToSPIRV( - IRModule mod, Target target); - } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/spirv_utils.h b/src/target/spirv/spirv_utils.h new file mode 100644 index 000000000000..26d89c4a17a5 --- /dev/null +++ b/src/target/spirv/spirv_utils.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TARGET_SPIRV_SPIRV_UTILS_H_ +#define TVM_TARGET_SPIRV_SPIRV_UTILS_H_ + +#include +#include + +#include +#include + +#include "../../runtime/vulkan/spirv_shader.h" + +namespace tvm { +namespace codegen { +/*! + * \brief Lower an IRModule to SPIRV modules. + * + * \param mod The IRModule to lower. + * \param target The target information. + * \return The map from function names to SPIRV binaries, and the concatenated text representation + * of the SPIRV modules. + */ +std::pair, std::string> LowerToSPIRV( + IRModule mod, Target target); + +} // namespace codegen +} // namespace tvm +#endif // TVM_TARGET_SPIRV_SPIRV_UTILS_H_ From 2ee4cef4cbba7543639404cbb9e9f30e6349f96e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 8 May 2023 03:50:09 +0900 Subject: [PATCH 05/13] Introduce spirv_utils.cc --- src/target/spirv/build_vulkan.cc | 138 ------------------------- src/target/spirv/spirv_utils.cc | 170 +++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+), 138 deletions(-) create mode 100644 src/target/spirv/spirv_utils.cc diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index e5561dad34cc..9dab3b6d8ef7 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -21,153 +21,15 @@ * \file build_vulkan.cc * \brief Build SPIRV block */ -// Use libspirv for parsing and validating code. -#include -#include -#include - -#include -#include #include "../../runtime/vulkan/spirv_shader.h" #include "../../runtime/vulkan/vulkan_module.h" -#include "../../support/utils.h" #include "../build_common.h" -#include "codegen_spirv.h" #include "spirv_utils.h" namespace tvm { namespace codegen { -class SPIRVTools { - public: - explicit SPIRVTools(Target target) { - uint32_t vulkan_version = - target->GetAttr("vulkan_api_version").value_or(VK_API_VERSION_1_0).IntValue(); - uint32_t spirv_version = - target->GetAttr("max_spirv_version").value_or(0x10000).IntValue(); - - spv_target_env validation_version; - if (target->kind->name == "opencl") { - validation_version = SPV_ENV_OPENCL_2_2; - } else { - if (vulkan_version >= VK_API_VERSION_1_2) { - validation_version = SPV_ENV_VULKAN_1_2; - } else if (vulkan_version >= VK_API_VERSION_1_1 && spirv_version >= 0x10400) { - validation_version = SPV_ENV_VULKAN_1_1_SPIRV_1_4; - } else if (vulkan_version >= VK_API_VERSION_1_1) { - validation_version = SPV_ENV_VULKAN_1_1; - } else { - validation_version = SPV_ENV_VULKAN_1_0; - } - } - ctx_ = spvContextCreate(validation_version); - } - - ~SPIRVTools() { spvContextDestroy(ctx_); } - - std::string BinaryToText(const std::vector& bin) { - spv_text text = nullptr; - spv_diagnostic diagnostic = nullptr; - spv_const_binary_t spv_bin{bin.data(), bin.size()}; - - spv_result_t res = - spvBinaryToText(ctx_, spv_bin.code, spv_bin.wordCount, - SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, - &text, &diagnostic); - - ICHECK_EQ(res, SPV_SUCCESS) << " line=" << diagnostic->position.line - << " column=" << diagnostic->position.column - << " index=" << diagnostic->position.index - << " error:" << diagnostic->error; - spvDiagnosticDestroy(diagnostic); - - std::string ret(text->str); - spvTextDestroy(text); - return ret; - } - - void ValidateShader(const std::vector& bin) { - spv_const_binary_t spv_bin{bin.data(), bin.size()}; - - spv_diagnostic diagnostic = nullptr; - spv_result_t res = spvValidate(ctx_, &spv_bin, &diagnostic); - - ICHECK_EQ(res, SPV_SUCCESS) << " index=" << diagnostic->position.index - << " error:" << diagnostic->error; - - spvDiagnosticDestroy(diagnostic); - } - - private: - spv_context ctx_; -}; - -std::pair, std::string> LowerToSPIRV( - IRModule mod, Target target) { - using tvm::runtime::Registry; - using tvm::runtime::SPIRVShader; - - std::ostringstream code_data; - SPIRVTools spirv_tools(target); - std::unordered_map smap; - - const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc"); - - mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); - - CodeGenSPIRV cg(target); - - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; - auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - - std::string f_name = global_symbol.value(); - std::string entry = f_name; - - SPIRVShader shader = cg.BuildFunction(f, entry); - - if (auto path = std::getenv("TVM_VULKAN_DEBUG_SHADER_SAVEPATH")) { - if (*path) { - std::stringstream ss; - ss << path << "/" << f_name << "_"; - std::string prefix = ss.str(); - - std::ofstream(prefix + "tir.txt") << f; - std::ofstream(prefix + "spv.txt") << spirv_tools.BinaryToText(shader.data); - std::ofstream(prefix + "spv.spv", std::ios::binary) - .write(reinterpret_cast(shader.data.data()), - sizeof(shader.data[0]) * shader.data.size()); - } - } - - if (!support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_SHADER_VALIDATION")) { - spirv_tools.ValidateShader(shader.data); - } - - if (postproc != nullptr) { - TVMByteArray arr; - arr.data = reinterpret_cast(dmlc::BeginPtr(shader.data)); - arr.size = shader.data.size() * sizeof(uint32_t); - std::string transformed = (*postproc)(arr); - ICHECK_EQ(transformed.length() % 4U, 0U); - shader.data.resize(transformed.size() / 4U); - std::copy(transformed.begin(), transformed.end(), - reinterpret_cast(dmlc::BeginPtr(shader.data))); - } - code_data << spirv_tools.BinaryToText(shader.data); - smap[f_name] = std::move(shader); - } - - return std::make_pair(smap, code_data.str()); -} - runtime::Module BuildSPIRV(IRModule mod, Target target) { auto [smap, spirv_text] = LowerToSPIRV(mod, target); return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), spirv_text); diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc new file mode 100644 index 000000000000..21aa4231c713 --- /dev/null +++ b/src/target/spirv/spirv_utils.cc @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file spirv_utils.cc + * \brief Build SPIRV block + */ +// Use libspirv for parsing and validating code. +#include +#include +#include + +#include +#include + +#include "../../runtime/vulkan/spirv_shader.h" +#include "../../support/utils.h" +#include "codegen_spirv.h" +#include "spirv_utils.h" + +namespace tvm { +namespace codegen { + +class SPIRVTools { + public: + explicit SPIRVTools(Target target) { + uint32_t vulkan_version = + target->GetAttr("vulkan_api_version").value_or(VK_API_VERSION_1_0).IntValue(); + uint32_t spirv_version = + target->GetAttr("max_spirv_version").value_or(0x10000).IntValue(); + + spv_target_env validation_version; + if (target->kind->name == "opencl") { + validation_version = SPV_ENV_OPENCL_2_2; + } else { + if (vulkan_version >= VK_API_VERSION_1_2) { + validation_version = SPV_ENV_VULKAN_1_2; + } else if (vulkan_version >= VK_API_VERSION_1_1 && spirv_version >= 0x10400) { + validation_version = SPV_ENV_VULKAN_1_1_SPIRV_1_4; + } else if (vulkan_version >= VK_API_VERSION_1_1) { + validation_version = SPV_ENV_VULKAN_1_1; + } else { + validation_version = SPV_ENV_VULKAN_1_0; + } + } + ctx_ = spvContextCreate(validation_version); + } + + ~SPIRVTools() { spvContextDestroy(ctx_); } + + std::string BinaryToText(const std::vector& bin) { + spv_text text = nullptr; + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t spv_bin{bin.data(), bin.size()}; + + spv_result_t res = + spvBinaryToText(ctx_, spv_bin.code, spv_bin.wordCount, + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, + &text, &diagnostic); + + ICHECK_EQ(res, SPV_SUCCESS) << " line=" << diagnostic->position.line + << " column=" << diagnostic->position.column + << " index=" << diagnostic->position.index + << " error:" << diagnostic->error; + spvDiagnosticDestroy(diagnostic); + + std::string ret(text->str); + spvTextDestroy(text); + return ret; + } + + void ValidateShader(const std::vector& bin) { + spv_const_binary_t spv_bin{bin.data(), bin.size()}; + + spv_diagnostic diagnostic = nullptr; + spv_result_t res = spvValidate(ctx_, &spv_bin, &diagnostic); + + ICHECK_EQ(res, SPV_SUCCESS) << " index=" << diagnostic->position.index + << " error:" << diagnostic->error; + + spvDiagnosticDestroy(diagnostic); + } + + private: + spv_context ctx_; +}; + +std::pair, std::string> LowerToSPIRV( + IRModule mod, Target target) { + using tvm::runtime::Registry; + using tvm::runtime::SPIRVShader; + + std::ostringstream code_data; + SPIRVTools spirv_tools(target); + std::unordered_map smap; + + const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc"); + + mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); + + CodeGenSPIRV cg(target); + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; + + std::string f_name = global_symbol.value(); + std::string entry = f_name; + + SPIRVShader shader = cg.BuildFunction(f, entry); + + if (auto path = std::getenv("TVM_VULKAN_DEBUG_SHADER_SAVEPATH")) { + if (*path) { + std::stringstream ss; + ss << path << "/" << f_name << "_"; + std::string prefix = ss.str(); + + std::ofstream(prefix + "tir.txt") << f; + std::ofstream(prefix + "spv.txt") << spirv_tools.BinaryToText(shader.data); + std::ofstream(prefix + "spv.spv", std::ios::binary) + .write(reinterpret_cast(shader.data.data()), + sizeof(shader.data[0]) * shader.data.size()); + } + } + + if (!support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_SHADER_VALIDATION")) { + spirv_tools.ValidateShader(shader.data); + } + + if (postproc != nullptr) { + TVMByteArray arr; + arr.data = reinterpret_cast(dmlc::BeginPtr(shader.data)); + arr.size = shader.data.size() * sizeof(uint32_t); + std::string transformed = (*postproc)(arr); + ICHECK_EQ(transformed.length() % 4U, 0U); + shader.data.resize(transformed.size() / 4U); + std::copy(transformed.begin(), transformed.end(), + reinterpret_cast(dmlc::BeginPtr(shader.data))); + } + code_data << spirv_tools.BinaryToText(shader.data); + smap[f_name] = std::move(shader); + } + + return std::make_pair(smap, code_data.str()); +} + +} // namespace codegen +} // namespace tvm From 414ccb6bb3011508b0789eded6eb4c3865ee2a30 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 8 May 2023 03:55:41 +0900 Subject: [PATCH 06/13] add dummy impl for LowerToSPIRV in case vulkan is not enabled --- cmake/modules/OpenCL.cmake | 1 + cmake/modules/Vulkan.cmake | 1 + src/runtime/vulkan/spirv_shader.h | 6 +++--- src/target/spirv/spirv_utils.cc | 24 +++++++++++++++++++++--- src/target/spirv/spirv_utils.h | 1 + 5 files changed, 27 insertions(+), 6 deletions(-) diff --git a/cmake/modules/OpenCL.cmake b/cmake/modules/OpenCL.cmake index 53199f19cb25..97bfaf76e188 100644 --- a/cmake/modules/OpenCL.cmake +++ b/cmake/modules/OpenCL.cmake @@ -41,6 +41,7 @@ endif(USE_AOCL) if(USE_OPENCL) tvm_file_glob(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc) + list(APPEND COMPILER_SRCS src/target/spirv/spirv_utils.cc) if(${USE_OPENCL} MATCHES ${IS_TRUE_PATTERN}) message(STATUS "Enabled runtime search for OpenCL library location") diff --git a/cmake/modules/Vulkan.cmake b/cmake/modules/Vulkan.cmake index 7470fb6125a4..7bd75877f103 100644 --- a/cmake/modules/Vulkan.cmake +++ b/cmake/modules/Vulkan.cmake @@ -34,4 +34,5 @@ if(USE_VULKAN) list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS}) list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY}) + add_definitions(-DTVM_USE_VULKAN=1) endif(USE_VULKAN) diff --git a/src/runtime/vulkan/spirv_shader.h b/src/runtime/vulkan/spirv_shader.h index 4a9f616d9bc1..3393692045ae 100644 --- a/src/runtime/vulkan/spirv_shader.h +++ b/src/runtime/vulkan/spirv_shader.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_VULKAN_VULKAN_SHADER_H_ -#define TVM_RUNTIME_VULKAN_VULKAN_SHADER_H_ +#ifndef TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_ +#define TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_ #include #include @@ -57,4 +57,4 @@ using vulkan::SPIRVShader; namespace dmlc { DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::vulkan::SPIRVShader, true); } // namespace dmlc -#endif // TVM_RUNTIME_VULKAN_VULKAN_SHADER_H_ +#endif // TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_ diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc index 21aa4231c713..efc120485eff 100644 --- a/src/target/spirv/spirv_utils.cc +++ b/src/target/spirv/spirv_utils.cc @@ -22,21 +22,29 @@ * \brief Build SPIRV block */ // Use libspirv for parsing and validating code. -#include +#include "spirv_utils.h" + +#if TVM_USE_VULKAN #include + +#include "codegen_spirv.h" +#endif + #include +#include #include #include +#include #include "../../runtime/vulkan/spirv_shader.h" #include "../../support/utils.h" -#include "codegen_spirv.h" -#include "spirv_utils.h" namespace tvm { namespace codegen { +#if TVM_USE_VULKAN + class SPIRVTools { public: explicit SPIRVTools(Target target) { @@ -166,5 +174,15 @@ std::pair, std::string> Lo return std::make_pair(smap, code_data.str()); } +#else + +std::pair, std::string> LowerToSPIRV( + IRModule mod, Target target) { + LOG(FATAL) << "LowerToSPIRV is called but Vulkan is not enabled."; + return {}; +} + +#endif + } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/spirv_utils.h b/src/target/spirv/spirv_utils.h index 26d89c4a17a5..cb14d32e9e70 100644 --- a/src/target/spirv/spirv_utils.h +++ b/src/target/spirv/spirv_utils.h @@ -24,6 +24,7 @@ #include #include +#include #include "../../runtime/vulkan/spirv_shader.h" From 70a8f84b307cbcb80a76032c86ef1c7e2b101bf6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 8 May 2023 06:03:16 +0900 Subject: [PATCH 07/13] more fix --- cmake/modules/OpenCL.cmake | 2 +- src/target/source/codegen_opencl.cc | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cmake/modules/OpenCL.cmake b/cmake/modules/OpenCL.cmake index 97bfaf76e188..f380ad75d14c 100644 --- a/cmake/modules/OpenCL.cmake +++ b/cmake/modules/OpenCL.cmake @@ -41,7 +41,7 @@ endif(USE_AOCL) if(USE_OPENCL) tvm_file_glob(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc) - list(APPEND COMPILER_SRCS src/target/spirv/spirv_utils.cc) + list(APPEND COMPILER_SRCS src/target/spirv/spirv_utils.cc) if(${USE_OPENCL} MATCHES ${IS_TRUE_PATTERN}) message(STATUS "Enabled runtime search for OpenCL library location") diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index d8a0eddcc1ab..5690066951a6 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -586,12 +586,13 @@ void CodeGenOpenCL::SetTextureScope( } runtime::Module BuildOpenCL(IRModule mod, Target target) { +#if TVM_USE_VULKAN Optional device = target->GetAttr("device"); - if (device && device.value() == "spirv") { auto [smap, spirv_text] = LowerToSPIRV(mod, target); return runtime::OpenCLModuleCreate(smap, spirv_text, ExtractFuncInfo(mod)); } +#endif using tvm::runtime::Registry; bool output_ssa = false; From 17827b2ab5c3d80f2bba29ac305110c0990f34ee Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 8 May 2023 14:09:35 +0900 Subject: [PATCH 08/13] TVM_USE_VULKAN -> TVM_ENABLE_SPIRV --- cmake/modules/Vulkan.cmake | 2 +- src/target/source/codegen_opencl.cc | 2 +- src/target/spirv/spirv_utils.cc | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cmake/modules/Vulkan.cmake b/cmake/modules/Vulkan.cmake index 7bd75877f103..1f303f3a032b 100644 --- a/cmake/modules/Vulkan.cmake +++ b/cmake/modules/Vulkan.cmake @@ -34,5 +34,5 @@ if(USE_VULKAN) list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS}) list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY}) - add_definitions(-DTVM_USE_VULKAN=1) + add_definitions(-DTVM_ENABLE_SPIRV=1) endif(USE_VULKAN) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 5690066951a6..de96f923e2fa 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -586,7 +586,7 @@ void CodeGenOpenCL::SetTextureScope( } runtime::Module BuildOpenCL(IRModule mod, Target target) { -#if TVM_USE_VULKAN +#if TVM_ENABLE_SPIRV Optional device = target->GetAttr("device"); if (device && device.value() == "spirv") { auto [smap, spirv_text] = LowerToSPIRV(mod, target); diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc index efc120485eff..fc84a9a64311 100644 --- a/src/target/spirv/spirv_utils.cc +++ b/src/target/spirv/spirv_utils.cc @@ -24,7 +24,7 @@ // Use libspirv for parsing and validating code. #include "spirv_utils.h" -#if TVM_USE_VULKAN +#if TVM_ENABLE_SPIRV #include #include "codegen_spirv.h" @@ -43,7 +43,7 @@ namespace tvm { namespace codegen { -#if TVM_USE_VULKAN +#if TVM_ENABLE_SPIRV class SPIRVTools { public: From 6e88482eb006b53bf16e375dc54905e5b5471246 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 8 May 2023 19:11:12 +0900 Subject: [PATCH 09/13] build fix when opencl is not enabled --- src/target/opt/build_opencl_off.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/target/opt/build_opencl_off.cc b/src/target/opt/build_opencl_off.cc index 2367500eca92..9e368d5599cf 100644 --- a/src/target/opt/build_opencl_off.cc +++ b/src/target/opt/build_opencl_off.cc @@ -31,5 +31,12 @@ Module OpenCLModuleCreate(std::string data, std::string fmt, return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl"); } +Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) { + LOG(FATAL) << "OpenCLModuleCreate is called but OpenCL is not enabled."; + return Module(); +} + } // namespace runtime } // namespace tvm From f706b811500fc6dc2b8a6e4dbb8639eb4e02ee24 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 8 May 2023 19:23:15 +0900 Subject: [PATCH 10/13] mv spirv_shader.h under new spirv folder --- src/runtime/opencl/opencl_module.cc | 2 +- src/runtime/opencl/opencl_module.h | 2 +- src/runtime/{vulkan => spirv}/spirv_shader.h | 6 +++--- src/runtime/vulkan/vulkan_module.h | 2 +- src/runtime/vulkan/vulkan_wrapped_func.h | 2 +- src/target/spirv/build_vulkan.cc | 2 +- src/target/spirv/codegen_spirv.h | 2 +- src/target/spirv/spirv_utils.cc | 2 +- src/target/spirv/spirv_utils.h | 2 +- 9 files changed, 11 insertions(+), 11 deletions(-) rename src/runtime/{vulkan => spirv}/spirv_shader.h (92%) diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index e39b4be17fb7..1058faeec672 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -30,7 +30,7 @@ #include #include "../source_utils.h" -#include "../vulkan/spirv_shader.h" +#include "../spirv/spirv_shader.h" #include "opencl_common.h" namespace tvm { diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 415704142e9e..ac8a7e74e75e 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -32,7 +32,7 @@ #include #include "../meta_data.h" -#include "../vulkan/spirv_shader.h" +#include "../spirv/spirv_shader.h" namespace tvm { namespace runtime { diff --git a/src/runtime/vulkan/spirv_shader.h b/src/runtime/spirv/spirv_shader.h similarity index 92% rename from src/runtime/vulkan/spirv_shader.h rename to src/runtime/spirv/spirv_shader.h index 3393692045ae..cda42f52db36 100644 --- a/src/runtime/vulkan/spirv_shader.h +++ b/src/runtime/spirv/spirv_shader.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_ -#define TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_ +#ifndef TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ +#define TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ #include #include @@ -57,4 +57,4 @@ using vulkan::SPIRVShader; namespace dmlc { DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::vulkan::SPIRVShader, true); } // namespace dmlc -#endif // TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_ +#endif // TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ diff --git a/src/runtime/vulkan/vulkan_module.h b/src/runtime/vulkan/vulkan_module.h index 075a5d60bb83..878e096f5ac1 100644 --- a/src/runtime/vulkan/vulkan_module.h +++ b/src/runtime/vulkan/vulkan_module.h @@ -24,7 +24,7 @@ #include #include "../meta_data.h" -#include "spirv_shader.h" +#include "../spirv/spirv_shader.h" namespace tvm { namespace runtime { diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index 31f61cdbb7d8..ab2d2a88ed86 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -30,7 +30,7 @@ #include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "spirv_shader.h" +#include "../spirv/spirv_shader.h" #include "vulkan/vulkan_core.h" #include "vulkan_common.h" #include "vulkan_device.h" diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 9dab3b6d8ef7..5690ef05de5c 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -22,7 +22,7 @@ * \brief Build SPIRV block */ -#include "../../runtime/vulkan/spirv_shader.h" +#include "../../runtime/spirv/spirv_shader.h" #include "../../runtime/vulkan/vulkan_module.h" #include "../build_common.h" #include "spirv_utils.h" diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 7564eb3ca608..475610e44081 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -37,7 +37,7 @@ #include #include "../../runtime/thread_storage_scope.h" -#include "../../runtime/vulkan/spirv_shader.h" +#include "../../runtime/spirv/spirv_shader.h" #include "ir_builder.h" #include "spirv_support.h" diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc index fc84a9a64311..7b65c0a59f67 100644 --- a/src/target/spirv/spirv_utils.cc +++ b/src/target/spirv/spirv_utils.cc @@ -37,7 +37,7 @@ #include #include -#include "../../runtime/vulkan/spirv_shader.h" +#include "../../runtime/spirv/spirv_shader.h" #include "../../support/utils.h" namespace tvm { diff --git a/src/target/spirv/spirv_utils.h b/src/target/spirv/spirv_utils.h index cb14d32e9e70..b441a559f813 100644 --- a/src/target/spirv/spirv_utils.h +++ b/src/target/spirv/spirv_utils.h @@ -26,7 +26,7 @@ #include #include -#include "../../runtime/vulkan/spirv_shader.h" +#include "../../runtime/spirv/spirv_shader.h" namespace tvm { namespace codegen { From ae6724409f553edb579614f30f1e135c7f1a7eb7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 8 May 2023 19:27:33 +0900 Subject: [PATCH 11/13] mv spirv module to its own file --- src/runtime/opencl/opencl_module.cc | 104 +--------------- src/runtime/opencl/opencl_module_spirv.cc | 137 ++++++++++++++++++++++ src/runtime/vulkan/vulkan_wrapped_func.h | 2 +- src/target/spirv/codegen_spirv.h | 2 +- 4 files changed, 140 insertions(+), 105 deletions(-) create mode 100644 src/runtime/opencl/opencl_module_spirv.cc diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 1058faeec672..45154ce2312c 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -30,7 +30,6 @@ #include #include "../source_utils.h" -#include "../spirv/spirv_shader.h" #include "opencl_common.h" namespace tvm { @@ -109,26 +108,6 @@ class OpenCLWrappedFunc { LaunchParamConfig launch_param_config_; }; -class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { - public: - explicit OpenCLSPIRVModuleNode(const std::unordered_map& shaders, - const std::string& spirv_text, - std::unordered_map fmap) - : OpenCLModuleNodeBase(fmap), shaders_(shaders), spirv_text_(spirv_text) {} - - void SaveToFile(const std::string& file_name, const std::string& format) final; - void SaveToBinary(dmlc::Stream* stream) final; - std::string GetSource(const std::string&) final { return spirv_text_; } - - void Init() override; - cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, - const std::string& func_name, const KTRefEntry& e) override; - - private: - std::unordered_map shaders_; - std::string spirv_text_; -}; - OpenCLModuleNodeBase::~OpenCLModuleNodeBase() { { // free the kernel ids in global table. @@ -250,7 +229,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre programs_[func_name][device_id] = clCreateProgramWithSource(w->contexts[platform], 1, &s, &len, &err); OPENCL_CHECK_ERROR(err); - } else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx" || fmt_ == "spirv") { + } else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") { const unsigned char* s = (const unsigned char*)data_.c_str(); size_t len = data_.length(); cl_int err; @@ -378,87 +357,6 @@ Module OpenCLModuleCreate(std::string data, std::string fmt, return Module(n); } -void OpenCLSPIRVModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { - // TODO(masahi): How SPIRV binaries should be save to a file? - LOG(FATAL) << "Not implemented."; -} - -void OpenCLSPIRVModuleNode::SaveToBinary(dmlc::Stream* stream) { - stream->Write(fmap_); - stream->Write(shaders_); -} - -void OpenCLSPIRVModuleNode::Init() { - workspace_ = GetGlobalWorkspace(); - workspace_->Init(); - // initialize the kernel id, need to lock global table. - std::lock_guard lock(workspace_->mu); - for (const auto& kv : fmap_) { - const std::string& key = kv.first; - KTRefEntry e; - if (workspace_->free_kernel_ids.size() != 0) { - e.kernel_id = workspace_->free_kernel_ids.back(); - workspace_->free_kernel_ids.pop_back(); - } else { - e.kernel_id = workspace_->num_registered_kernels++; - } - e.version = workspace_->timestamp++; - kid_map_[key] = e; - } - - // zero initialize cl_program pointers for each device kernel - for (auto& kv : shaders_) { - programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); - } -} - -cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, - const std::string& func_name, const KTRefEntry& e) { - std::lock_guard lock(build_lock_); - int device_id = t->device.device_id; - if (programs_[func_name][device_id] == nullptr) { - auto it = shaders_.find(func_name); - const unsigned char* s = (const unsigned char*)it->second.data.data(); - size_t len = it->second.data.size() * sizeof(uint32_t); - cl_int err; - cl_device_id dev = w->devices[device_id]; - auto platform = w->device_to_platform[dev]; - programs_[func_name][device_id] = - clCreateProgramWithBinary(w->contexts[platform], 1, &dev, &len, &s, nullptr, &err); - OPENCL_CHECK_ERROR(err); - - // build program - err = clBuildProgram(programs_[func_name][device_id], 1, &dev, nullptr, nullptr, nullptr); - - if (err != CL_SUCCESS) { - size_t len; - std::string log; - clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, - &len); - log.resize(len); - clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, len, - &log[0], nullptr); - LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log; - } - } - // build kernel - cl_int err; - cl_kernel kernel = clCreateKernel(programs_[func_name][device_id], func_name.c_str(), &err); - OPENCL_CHECK_ERROR(err); - t->kernel_table[e.kernel_id].kernel = kernel; - t->kernel_table[e.kernel_id].version = e.version; - kernels_.push_back(kernel); - return kernel; -} - -Module OpenCLModuleCreate(const std::unordered_map& shaders, - const std::string& spirv_text, - std::unordered_map fmap) { - auto n = make_object(shaders, spirv_text, fmap); - n->Init(); - return Module(n); -} - // Load module from module. Module OpenCLModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; diff --git a/src/runtime/opencl/opencl_module_spirv.cc b/src/runtime/opencl/opencl_module_spirv.cc new file mode 100644 index 000000000000..5e3ecf2eeb8b --- /dev/null +++ b/src/runtime/opencl/opencl_module_spirv.cc @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include +#include + +#include "../source_utils.h" +#include "../spirv/spirv_shader.h" +#include "opencl_common.h" +#include "opencl_module.h" + +namespace tvm { +namespace runtime { + +class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { + public: + explicit OpenCLSPIRVModuleNode(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) + : OpenCLModuleNodeBase(fmap), shaders_(shaders), spirv_text_(spirv_text) {} + + void SaveToFile(const std::string& file_name, const std::string& format) final; + void SaveToBinary(dmlc::Stream* stream) final; + std::string GetSource(const std::string&) final { return spirv_text_; } + + void Init() override; + cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) override; + + private: + std::unordered_map shaders_; + std::string spirv_text_; +}; + +void OpenCLSPIRVModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { + // TODO(masahi): How SPIRV binaries should be save to a file? + LOG(FATAL) << "Not implemented."; +} + +void OpenCLSPIRVModuleNode::SaveToBinary(dmlc::Stream* stream) { + stream->Write(fmap_); + stream->Write(shaders_); +} + +void OpenCLSPIRVModuleNode::Init() { + workspace_ = GetGlobalWorkspace(); + workspace_->Init(); + // initialize the kernel id, need to lock global table. + std::lock_guard lock(workspace_->mu); + for (const auto& kv : fmap_) { + const std::string& key = kv.first; + KTRefEntry e; + if (workspace_->free_kernel_ids.size() != 0) { + e.kernel_id = workspace_->free_kernel_ids.back(); + workspace_->free_kernel_ids.pop_back(); + } else { + e.kernel_id = workspace_->num_registered_kernels++; + } + e.version = workspace_->timestamp++; + kid_map_[key] = e; + } + + // zero initialize cl_program pointers for each device kernel + for (auto& kv : shaders_) { + programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); + } +} + +cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) { + std::lock_guard lock(build_lock_); + int device_id = t->device.device_id; + if (programs_[func_name][device_id] == nullptr) { + auto it = shaders_.find(func_name); + const unsigned char* s = (const unsigned char*)it->second.data.data(); + size_t len = it->second.data.size() * sizeof(uint32_t); + cl_int err; + cl_device_id dev = w->devices[device_id]; + auto platform = w->device_to_platform[dev]; + programs_[func_name][device_id] = + clCreateProgramWithBinary(w->contexts[platform], 1, &dev, &len, &s, nullptr, &err); + OPENCL_CHECK_ERROR(err); + + // build program + err = clBuildProgram(programs_[func_name][device_id], 1, &dev, nullptr, nullptr, nullptr); + + if (err != CL_SUCCESS) { + size_t len; + std::string log; + clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, + &len); + log.resize(len); + clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, len, + &log[0], nullptr); + LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log; + } + } + // build kernel + cl_int err; + cl_kernel kernel = clCreateKernel(programs_[func_name][device_id], func_name.c_str(), &err); + OPENCL_CHECK_ERROR(err); + t->kernel_table[e.kernel_id].kernel = kernel; + t->kernel_table[e.kernel_id].version = e.version; + kernels_.push_back(kernel); + return kernel; +} + +Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) { + auto n = make_object(shaders, spirv_text, fmap); + n->Init(); + return Module(n); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index ab2d2a88ed86..285edcd3533d 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -29,8 +29,8 @@ #include "../meta_data.h" #include "../pack_args.h" -#include "../thread_storage_scope.h" #include "../spirv/spirv_shader.h" +#include "../thread_storage_scope.h" #include "vulkan/vulkan_core.h" #include "vulkan_common.h" #include "vulkan_device.h" diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 475610e44081..f2d771070ed9 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -36,8 +36,8 @@ #include #include -#include "../../runtime/thread_storage_scope.h" #include "../../runtime/spirv/spirv_shader.h" +#include "../../runtime/thread_storage_scope.h" #include "ir_builder.h" #include "spirv_support.h" From ec7f6806b210099f3ddd6a29eecec7aa1aa7279d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 9 May 2023 00:59:07 +0900 Subject: [PATCH 12/13] fix after reorg --- src/runtime/spirv/spirv_shader.h | 8 ++++---- src/target/spirv/spirv_utils.cc | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/runtime/spirv/spirv_shader.h b/src/runtime/spirv/spirv_shader.h index cda42f52db36..293dc5b78638 100644 --- a/src/runtime/spirv/spirv_shader.h +++ b/src/runtime/spirv/spirv_shader.h @@ -29,7 +29,7 @@ namespace tvm { namespace runtime { -namespace vulkan { +namespace spirv { struct SPIRVShader { /*! \brief header flag */ @@ -48,13 +48,13 @@ struct SPIRVShader { } }; -} // namespace vulkan +} // namespace spirv -using vulkan::SPIRVShader; +using spirv::SPIRVShader; } // namespace runtime } // namespace tvm namespace dmlc { -DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::vulkan::SPIRVShader, true); +DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::spirv::SPIRVShader, true); } // namespace dmlc #endif // TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc index 7b65c0a59f67..2a9110d87124 100644 --- a/src/target/spirv/spirv_utils.cc +++ b/src/target/spirv/spirv_utils.cc @@ -178,7 +178,8 @@ std::pair, std::string> Lo std::pair, std::string> LowerToSPIRV( IRModule mod, Target target) { - LOG(FATAL) << "LowerToSPIRV is called but Vulkan is not enabled."; + LOG(FATAL) + << "LowerToSPIRV is called but SPIRV codegen is not enabled. Please set -DUSE_VULKAN=ON."; return {}; } From f572d46e6e1f22852085176f4d86e361021be56e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 9 May 2023 07:34:06 +0900 Subject: [PATCH 13/13] build fix --- src/runtime/opencl/opencl_module.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index ac8a7e74e75e..834f53510ecc 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -53,7 +53,7 @@ Module OpenCLModuleCreate(std::string data, std::string fmt, * \param spirv_text The concatenated text representation of SPIRV modules. * \param fmap The map function information map of each function. */ -Module OpenCLModuleCreate(const std::unordered_map& shaders, +Module OpenCLModuleCreate(const std::unordered_map& shaders, const std::string& spirv_text, std::unordered_map fmap); } // namespace runtime