diff --git a/cmake/modules/OpenCL.cmake b/cmake/modules/OpenCL.cmake index 53199f19cb25..f380ad75d14c 100644 --- a/cmake/modules/OpenCL.cmake +++ b/cmake/modules/OpenCL.cmake @@ -41,6 +41,7 @@ endif(USE_AOCL) if(USE_OPENCL) tvm_file_glob(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc) + list(APPEND COMPILER_SRCS src/target/spirv/spirv_utils.cc) if(${USE_OPENCL} MATCHES ${IS_TRUE_PATTERN}) message(STATUS "Enabled runtime search for OpenCL library location") diff --git a/cmake/modules/Vulkan.cmake b/cmake/modules/Vulkan.cmake index 7470fb6125a4..1f303f3a032b 100644 --- a/cmake/modules/Vulkan.cmake +++ b/cmake/modules/Vulkan.cmake @@ -34,4 +34,5 @@ if(USE_VULKAN) list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS}) list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY}) + add_definitions(-DTVM_ENABLE_SPIRV=1) endif(USE_VULKAN) diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index a8a4cf3dc65c..d25d2db0eb9f 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -411,18 +411,16 @@ struct BufferDescriptor { // To make the call thread-safe, we create a thread-local kernel table // and lazily install new kernels into the kernel table when the kernel is called. // The kernels are recycled when the module get destructed. -class OpenCLModuleNode : public ModuleNode { +class OpenCLModuleNodeBase : public ModuleNode { public: // Kernel table reference entry. struct KTRefEntry { size_t kernel_id; size_t version; }; - explicit OpenCLModuleNode(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) - : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {} + explicit OpenCLModuleNodeBase(std::unordered_map fmap) : fmap_(fmap) {} // destructor - ~OpenCLModuleNode(); + ~OpenCLModuleNodeBase(); /*! * \brief Get the global workspace @@ -436,38 +434,56 @@ class OpenCLModuleNode : public ModuleNode { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; } - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, const std::string& format) final; - void SaveToBinary(dmlc::Stream* stream) final; - std::string GetSource(const std::string& format) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override; + // Initialize the programs - void Init(); + virtual void Init() = 0; // install a new kernel to thread local entry - cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, - const std::string& func_name, const KTRefEntry& e); - void SetPreCompiledPrograms(const std::string& bytes); - std::string GetPreCompiledPrograms(); + virtual cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) = 0; - private: + protected: // The workspace, need to keep reference to use it in destructor. // In case of static destruction order problem. cl::OpenCLWorkspace* workspace_; - // the binary data - std::string data_; - // The format - std::string fmt_; // function information table. std::unordered_map fmap_; // Module local mutex std::mutex build_lock_; - // The OpenCL source. - std::string source_; // Mapping from primitive name to cl program for each device. std::unordered_map> programs_; // kernel id cache std::unordered_map kid_map_; - // kernels build so far. + // kernels built so far. std::vector kernels_; +}; + +class OpenCLModuleNode : public OpenCLModuleNodeBase { + public: + explicit OpenCLModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) + : OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + void SaveToFile(const std::string& file_name, const std::string& format) final; + void SaveToBinary(dmlc::Stream* stream) final; + void SetPreCompiledPrograms(const std::string& bytes); + std::string GetPreCompiledPrograms(); + std::string GetSource(const std::string& format) final; + + // Initialize the programs + void Init() override; + // install a new kernel to thread local entry + cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) override; + + private: + // the binary data + std::string data_; + // The format + std::string fmt_; + // The OpenCL source. + std::string source_; // parsed kernel data std::unordered_map parsed_kernels_; }; diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 7c084758a456..45154ce2312c 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -38,7 +38,7 @@ namespace runtime { class OpenCLWrappedFunc { public: // initialize the OpenCL function. - void Init(OpenCLModuleNode* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, + void Init(OpenCLModuleNodeBase* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, std::string func_name, std::vector arg_size, const std::vector& launch_param_tags) { w_ = m->GetGlobalWorkspace(); @@ -95,7 +95,7 @@ class OpenCLWrappedFunc { // global workspace. cl::OpenCLWorkspace* w_; // The module - OpenCLModuleNode* m_; + OpenCLModuleNodeBase* m_; // resource handle ObjectPtr sptr_; // global kernel id in the kernel table. @@ -108,7 +108,7 @@ class OpenCLWrappedFunc { LaunchParamConfig launch_param_config_; }; -OpenCLModuleNode::~OpenCLModuleNode() { +OpenCLModuleNodeBase::~OpenCLModuleNodeBase() { { // free the kernel ids in global table. std::lock_guard lock(workspace_->mu); @@ -130,22 +130,13 @@ OpenCLModuleNode::~OpenCLModuleNode() { } } -cl::OpenCLWorkspace* OpenCLModuleNode::GetGlobalWorkspace() { +cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() { return cl::OpenCLWorkspace::Global(); } -PackedFunc OpenCLModuleNode::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc OpenCLModuleNodeBase::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); - if (name == "opencl.GetPreCompiledPrograms") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetPreCompiledPrograms(); - }); - } else if (name == "opencl.SetPreCompiledPrograms") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->SetPreCompiledPrograms(args[0]); - }); - } ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); @@ -344,6 +335,21 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { return data; } +PackedFunc OpenCLModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + ICHECK_EQ(sptr_to_self.get(), this); + if (name == "opencl.GetPreCompiledPrograms") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetPreCompiledPrograms(); + }); + } else if (name == "opencl.SetPreCompiledPrograms") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + this->SetPreCompiledPrograms(args[0]); + }); + } + return OpenCLModuleNodeBase::GetFunction(name, sptr_to_self); +} + Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 77f4b8010779..834f53510ecc 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -32,11 +32,12 @@ #include #include "../meta_data.h" +#include "../spirv/spirv_shader.h" namespace tvm { namespace runtime { /*! - * \brief create a opencl module for GPU devices from data. + * \brief Create a opencl module for GPU devices from data. * * \param data The module data. * \param fmt The format of the data, can be "clbin", "cl" @@ -44,6 +45,17 @@ namespace runtime { */ Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source); + +/*! + * \brief Create a opencl module from SPIRV. + * + * \param shaders The map from function names to SPIRV binaries. + * \param spirv_text The concatenated text representation of SPIRV modules. + * \param fmap The map function information map of each function. + */ +Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ diff --git a/src/runtime/opencl/opencl_module_spirv.cc b/src/runtime/opencl/opencl_module_spirv.cc new file mode 100644 index 000000000000..5e3ecf2eeb8b --- /dev/null +++ b/src/runtime/opencl/opencl_module_spirv.cc @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include +#include + +#include "../source_utils.h" +#include "../spirv/spirv_shader.h" +#include "opencl_common.h" +#include "opencl_module.h" + +namespace tvm { +namespace runtime { + +class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { + public: + explicit OpenCLSPIRVModuleNode(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) + : OpenCLModuleNodeBase(fmap), shaders_(shaders), spirv_text_(spirv_text) {} + + void SaveToFile(const std::string& file_name, const std::string& format) final; + void SaveToBinary(dmlc::Stream* stream) final; + std::string GetSource(const std::string&) final { return spirv_text_; } + + void Init() override; + cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) override; + + private: + std::unordered_map shaders_; + std::string spirv_text_; +}; + +void OpenCLSPIRVModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { + // TODO(masahi): How SPIRV binaries should be save to a file? + LOG(FATAL) << "Not implemented."; +} + +void OpenCLSPIRVModuleNode::SaveToBinary(dmlc::Stream* stream) { + stream->Write(fmap_); + stream->Write(shaders_); +} + +void OpenCLSPIRVModuleNode::Init() { + workspace_ = GetGlobalWorkspace(); + workspace_->Init(); + // initialize the kernel id, need to lock global table. + std::lock_guard lock(workspace_->mu); + for (const auto& kv : fmap_) { + const std::string& key = kv.first; + KTRefEntry e; + if (workspace_->free_kernel_ids.size() != 0) { + e.kernel_id = workspace_->free_kernel_ids.back(); + workspace_->free_kernel_ids.pop_back(); + } else { + e.kernel_id = workspace_->num_registered_kernels++; + } + e.version = workspace_->timestamp++; + kid_map_[key] = e; + } + + // zero initialize cl_program pointers for each device kernel + for (auto& kv : shaders_) { + programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); + } +} + +cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) { + std::lock_guard lock(build_lock_); + int device_id = t->device.device_id; + if (programs_[func_name][device_id] == nullptr) { + auto it = shaders_.find(func_name); + const unsigned char* s = (const unsigned char*)it->second.data.data(); + size_t len = it->second.data.size() * sizeof(uint32_t); + cl_int err; + cl_device_id dev = w->devices[device_id]; + auto platform = w->device_to_platform[dev]; + programs_[func_name][device_id] = + clCreateProgramWithBinary(w->contexts[platform], 1, &dev, &len, &s, nullptr, &err); + OPENCL_CHECK_ERROR(err); + + // build program + err = clBuildProgram(programs_[func_name][device_id], 1, &dev, nullptr, nullptr, nullptr); + + if (err != CL_SUCCESS) { + size_t len; + std::string log; + clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, + &len); + log.resize(len); + clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, len, + &log[0], nullptr); + LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log; + } + } + // build kernel + cl_int err; + cl_kernel kernel = clCreateKernel(programs_[func_name][device_id], func_name.c_str(), &err); + OPENCL_CHECK_ERROR(err); + t->kernel_table[e.kernel_id].kernel = kernel; + t->kernel_table[e.kernel_id].version = e.version; + kernels_.push_back(kernel); + return kernel; +} + +Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) { + auto n = make_object(shaders, spirv_text, fmap); + n->Init(); + return Module(n); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_shader.h b/src/runtime/spirv/spirv_shader.h similarity index 82% rename from src/runtime/vulkan/vulkan_shader.h rename to src/runtime/spirv/spirv_shader.h index 513e3bccc36e..293dc5b78638 100644 --- a/src/runtime/vulkan/vulkan_shader.h +++ b/src/runtime/spirv/spirv_shader.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_VULKAN_VULKAN_SHADER_H_ -#define TVM_RUNTIME_VULKAN_VULKAN_SHADER_H_ +#ifndef TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ +#define TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ #include #include @@ -29,9 +29,9 @@ namespace tvm { namespace runtime { -namespace vulkan { +namespace spirv { -struct VulkanShader { +struct SPIRVShader { /*! \brief header flag */ uint32_t flag{0}; /*! \brief Data segment */ @@ -48,13 +48,13 @@ struct VulkanShader { } }; -} // namespace vulkan +} // namespace spirv -using vulkan::VulkanShader; +using spirv::SPIRVShader; } // namespace runtime } // namespace tvm namespace dmlc { -DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::vulkan::VulkanShader, true); +DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::spirv::SPIRVShader, true); } // namespace dmlc -#endif // TVM_RUNTIME_VULKAN_VULKAN_SHADER_H_ +#endif // TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index 89104d9d63d9..232cf1d58ec7 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -29,7 +29,7 @@ namespace tvm { namespace runtime { namespace vulkan { -Module VulkanModuleCreate(std::unordered_map smap, +Module VulkanModuleCreate(std::unordered_map smap, std::unordered_map fmap, std::string source) { auto n = make_object(smap, fmap, source); return Module(n); @@ -37,7 +37,7 @@ Module VulkanModuleCreate(std::unordered_map smap, Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; - std::unordered_map smap; + std::unordered_map smap; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); @@ -54,7 +54,7 @@ Module VulkanModuleLoadFile(const std::string& file_name, const std::string& for Module VulkanModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); - std::unordered_map smap; + std::unordered_map smap; std::unordered_map fmap; std::string fmt; diff --git a/src/runtime/vulkan/vulkan_module.h b/src/runtime/vulkan/vulkan_module.h index c75a077a361d..878e096f5ac1 100644 --- a/src/runtime/vulkan/vulkan_module.h +++ b/src/runtime/vulkan/vulkan_module.h @@ -24,12 +24,12 @@ #include #include "../meta_data.h" -#include "vulkan_shader.h" +#include "../spirv/spirv_shader.h" namespace tvm { namespace runtime { namespace vulkan { -Module VulkanModuleCreate(std::unordered_map smap, +Module VulkanModuleCreate(std::unordered_map smap, std::unordered_map fmap, std::string source); } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index 187736e82a6d..285edcd3533d 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -29,11 +29,11 @@ #include "../meta_data.h" #include "../pack_args.h" +#include "../spirv/spirv_shader.h" #include "../thread_storage_scope.h" #include "vulkan/vulkan_core.h" #include "vulkan_common.h" #include "vulkan_device.h" -#include "vulkan_shader.h" namespace tvm { namespace runtime { @@ -82,7 +82,7 @@ class VulkanWrappedFunc { class VulkanModuleNode final : public runtime::ModuleNode { public: - explicit VulkanModuleNode(std::unordered_map smap, + explicit VulkanModuleNode(std::unordered_map smap, std::unordered_map fmap, std::string source) : smap_(smap), fmap_(fmap), source_(source) {} ~VulkanModuleNode(); @@ -106,7 +106,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { private: // function information table. - std::unordered_map smap_; + std::unordered_map smap_; // function information table. std::unordered_map fmap_; // The format diff --git a/src/target/opt/build_opencl_off.cc b/src/target/opt/build_opencl_off.cc index 2367500eca92..9e368d5599cf 100644 --- a/src/target/opt/build_opencl_off.cc +++ b/src/target/opt/build_opencl_off.cc @@ -31,5 +31,12 @@ Module OpenCLModuleCreate(std::string data, std::string fmt, return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl"); } +Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) { + LOG(FATAL) << "OpenCLModuleCreate is called but OpenCL is not enabled."; + return Module(); +} + } // namespace runtime } // namespace tvm diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 525ee95f4117..de96f923e2fa 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -30,6 +30,7 @@ #include "../../runtime/texture.h" #include "../../runtime/thread_storage_scope.h" #include "../build_common.h" +#include "../spirv/spirv_utils.h" namespace tvm { namespace codegen { @@ -585,6 +586,14 @@ void CodeGenOpenCL::SetTextureScope( } runtime::Module BuildOpenCL(IRModule mod, Target target) { +#if TVM_ENABLE_SPIRV + Optional device = target->GetAttr("device"); + if (device && device.value() == "spirv") { + auto [smap, spirv_text] = LowerToSPIRV(mod, target); + return runtime::OpenCLModuleCreate(smap, spirv_text, ExtractFuncInfo(mod)); + } +#endif + using tvm::runtime::Registry; bool output_ssa = false; diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index dc1d8f865baa..5690ef05de5c 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -21,144 +21,18 @@ * \file build_vulkan.cc * \brief Build SPIRV block */ -// Use libspirv for parsing and validating code. -#include -#include -#include - -#include -#include +#include "../../runtime/spirv/spirv_shader.h" #include "../../runtime/vulkan/vulkan_module.h" -#include "../../runtime/vulkan/vulkan_shader.h" -#include "../../support/utils.h" #include "../build_common.h" -#include "codegen_spirv.h" +#include "spirv_utils.h" namespace tvm { namespace codegen { -class SPIRVTools { - public: - explicit SPIRVTools(Target target) { - uint32_t vulkan_version = - target->GetAttr("vulkan_api_version").value_or(VK_API_VERSION_1_0).IntValue(); - uint32_t spirv_version = - target->GetAttr("max_spirv_version").value_or(0x10000).IntValue(); - - spv_target_env validation_version; - if (vulkan_version >= VK_API_VERSION_1_2) { - validation_version = SPV_ENV_VULKAN_1_2; - } else if (vulkan_version >= VK_API_VERSION_1_1 && spirv_version >= 0x10400) { - validation_version = SPV_ENV_VULKAN_1_1_SPIRV_1_4; - } else if (vulkan_version >= VK_API_VERSION_1_1) { - validation_version = SPV_ENV_VULKAN_1_1; - } else { - validation_version = SPV_ENV_VULKAN_1_0; - } - - ctx_ = spvContextCreate(validation_version); - } - ~SPIRVTools() { spvContextDestroy(ctx_); } - std::string BinaryToText(const std::vector& bin) { - spv_text text = nullptr; - spv_diagnostic diagnostic = nullptr; - spv_const_binary_t spv_bin{bin.data(), bin.size()}; - - spv_result_t res = - spvBinaryToText(ctx_, spv_bin.code, spv_bin.wordCount, - SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, - &text, &diagnostic); - - ICHECK_EQ(res, SPV_SUCCESS) << " line=" << diagnostic->position.line - << " column=" << diagnostic->position.column - << " index=" << diagnostic->position.index - << " error:" << diagnostic->error; - spvDiagnosticDestroy(diagnostic); - - std::string ret(text->str); - spvTextDestroy(text); - return ret; - } - - void ValidateShader(const std::vector& bin) { - spv_const_binary_t spv_bin{bin.data(), bin.size()}; - - spv_diagnostic diagnostic = nullptr; - spv_result_t res = spvValidate(ctx_, &spv_bin, &diagnostic); - - ICHECK_EQ(res, SPV_SUCCESS) << " index=" << diagnostic->position.index - << " error:" << diagnostic->error; - - spvDiagnosticDestroy(diagnostic); - } - - private: - spv_context ctx_; -}; - runtime::Module BuildSPIRV(IRModule mod, Target target) { - using tvm::runtime::Registry; - using tvm::runtime::VulkanShader; - - std::ostringstream code_data; - SPIRVTools spirv_tools(target); - std::unordered_map smap; - - const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc"); - - mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); - - CodeGenSPIRV cg(target); - - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; - auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - - std::string f_name = global_symbol.value(); - std::string entry = f_name; - - VulkanShader shader = cg.BuildFunction(f, entry); - - if (auto path = std::getenv("TVM_VULKAN_DEBUG_SHADER_SAVEPATH")) { - if (*path) { - std::stringstream ss; - ss << path << "/" << f_name << "_"; - std::string prefix = ss.str(); - - std::ofstream(prefix + "tir.txt") << f; - std::ofstream(prefix + "spv.txt") << spirv_tools.BinaryToText(shader.data); - std::ofstream(prefix + "spv.spv", std::ios::binary) - .write(reinterpret_cast(shader.data.data()), - sizeof(shader.data[0]) * shader.data.size()); - } - } - - if (!support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_SHADER_VALIDATION")) { - spirv_tools.ValidateShader(shader.data); - } - - if (postproc != nullptr) { - TVMByteArray arr; - arr.data = reinterpret_cast(dmlc::BeginPtr(shader.data)); - arr.size = shader.data.size() * sizeof(uint32_t); - std::string transformed = (*postproc)(arr); - ICHECK_EQ(transformed.length() % 4U, 0U); - shader.data.resize(transformed.size() / 4U); - std::copy(transformed.begin(), transformed.end(), - reinterpret_cast(dmlc::BeginPtr(shader.data))); - } - code_data << spirv_tools.BinaryToText(shader.data); - smap[f_name] = std::move(shader); - } - - return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), code_data.str()); + auto [smap, spirv_text] = LowerToSPIRV(mod, target); + return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), spirv_text); } TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) { diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index e3ef5acb8331..2a4233b44bcf 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -31,7 +31,6 @@ #include "../../runtime/pack_args.h" #include "../../runtime/vulkan/vulkan_common.h" -#include "../../runtime/vulkan/vulkan_shader.h" #include "../../tir/transforms/ir_utils.h" namespace tvm { @@ -39,7 +38,7 @@ namespace codegen { CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {} -runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { +runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; @@ -79,7 +78,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: spirv::Value func_ptr = builder_->NewFunction(); builder_->StartFunction(func_ptr); - runtime::VulkanShader shader; + runtime::SPIRVShader shader; if (pod_args.size() != 0) { std::vector value_types; diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 08b9db0ee539..f2d771070ed9 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -36,8 +36,8 @@ #include #include +#include "../../runtime/spirv/spirv_shader.h" #include "../../runtime/thread_storage_scope.h" -#include "../../runtime/vulkan/vulkan_shader.h" #include "ir_builder.h" #include "spirv_support.h" @@ -66,7 +66,7 @@ class CodeGenSPIRV : public ExprFunctor, * \param name The name of the target function. * \return The final spirv module. */ - virtual runtime::VulkanShader BuildFunction(const PrimFunc& f, const std::string& name); + virtual runtime::SPIRVShader BuildFunction(const PrimFunc& f, const std::string& name); /*! * \brief Create Value for expression e * \param e The expression to be created value for. @@ -153,7 +153,7 @@ class CodeGenSPIRV : public ExprFunctor, * product of the number of lanes of the buffer element type and * the number of lanes of the index. */ - void CheckContentType(DataType type, int index_lanes = 1) { + void CheckContentType(DataType type, int index_lanes = 1) const { ICHECK(element_type_known) << "Cannot check element type of buffer " << name_hint << " no previous element type defined"; DataType expected_type = element_type.with_lanes(index_lanes * element_type.lanes()); diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index 81b5cd8b8a6a..1b46e7f08339 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -32,8 +32,9 @@ namespace tvm { namespace codegen { SPIRVSupport::SPIRVSupport(tvm::Target target) { - ICHECK_EQ(target->GetTargetDeviceType(), kDLVulkan) - << "SPIRVSupport can only be checked for vulkan device type"; + auto device_type = target->GetTargetDeviceType(); + ICHECK(device_type == kDLVulkan || device_type == kDLOpenCL || device_type == kDLWebGPU) + << "Unsupported device type for SPIRV codegen:" << device_type; if (target->GetAttr("vulkan_api_version")) { vulkan_api_version = target->GetAttr("vulkan_api_version").value().IntValue(); diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc new file mode 100644 index 000000000000..2a9110d87124 --- /dev/null +++ b/src/target/spirv/spirv_utils.cc @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file spirv_utils.cc + * \brief Build SPIRV block + */ +// Use libspirv for parsing and validating code. +#include "spirv_utils.h" + +#if TVM_ENABLE_SPIRV +#include + +#include "codegen_spirv.h" +#endif + +#include + +#include +#include +#include +#include + +#include "../../runtime/spirv/spirv_shader.h" +#include "../../support/utils.h" + +namespace tvm { +namespace codegen { + +#if TVM_ENABLE_SPIRV + +class SPIRVTools { + public: + explicit SPIRVTools(Target target) { + uint32_t vulkan_version = + target->GetAttr("vulkan_api_version").value_or(VK_API_VERSION_1_0).IntValue(); + uint32_t spirv_version = + target->GetAttr("max_spirv_version").value_or(0x10000).IntValue(); + + spv_target_env validation_version; + if (target->kind->name == "opencl") { + validation_version = SPV_ENV_OPENCL_2_2; + } else { + if (vulkan_version >= VK_API_VERSION_1_2) { + validation_version = SPV_ENV_VULKAN_1_2; + } else if (vulkan_version >= VK_API_VERSION_1_1 && spirv_version >= 0x10400) { + validation_version = SPV_ENV_VULKAN_1_1_SPIRV_1_4; + } else if (vulkan_version >= VK_API_VERSION_1_1) { + validation_version = SPV_ENV_VULKAN_1_1; + } else { + validation_version = SPV_ENV_VULKAN_1_0; + } + } + ctx_ = spvContextCreate(validation_version); + } + + ~SPIRVTools() { spvContextDestroy(ctx_); } + + std::string BinaryToText(const std::vector& bin) { + spv_text text = nullptr; + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t spv_bin{bin.data(), bin.size()}; + + spv_result_t res = + spvBinaryToText(ctx_, spv_bin.code, spv_bin.wordCount, + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, + &text, &diagnostic); + + ICHECK_EQ(res, SPV_SUCCESS) << " line=" << diagnostic->position.line + << " column=" << diagnostic->position.column + << " index=" << diagnostic->position.index + << " error:" << diagnostic->error; + spvDiagnosticDestroy(diagnostic); + + std::string ret(text->str); + spvTextDestroy(text); + return ret; + } + + void ValidateShader(const std::vector& bin) { + spv_const_binary_t spv_bin{bin.data(), bin.size()}; + + spv_diagnostic diagnostic = nullptr; + spv_result_t res = spvValidate(ctx_, &spv_bin, &diagnostic); + + ICHECK_EQ(res, SPV_SUCCESS) << " index=" << diagnostic->position.index + << " error:" << diagnostic->error; + + spvDiagnosticDestroy(diagnostic); + } + + private: + spv_context ctx_; +}; + +std::pair, std::string> LowerToSPIRV( + IRModule mod, Target target) { + using tvm::runtime::Registry; + using tvm::runtime::SPIRVShader; + + std::ostringstream code_data; + SPIRVTools spirv_tools(target); + std::unordered_map smap; + + const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc"); + + mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); + + CodeGenSPIRV cg(target); + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; + + std::string f_name = global_symbol.value(); + std::string entry = f_name; + + SPIRVShader shader = cg.BuildFunction(f, entry); + + if (auto path = std::getenv("TVM_VULKAN_DEBUG_SHADER_SAVEPATH")) { + if (*path) { + std::stringstream ss; + ss << path << "/" << f_name << "_"; + std::string prefix = ss.str(); + + std::ofstream(prefix + "tir.txt") << f; + std::ofstream(prefix + "spv.txt") << spirv_tools.BinaryToText(shader.data); + std::ofstream(prefix + "spv.spv", std::ios::binary) + .write(reinterpret_cast(shader.data.data()), + sizeof(shader.data[0]) * shader.data.size()); + } + } + + if (!support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_SHADER_VALIDATION")) { + spirv_tools.ValidateShader(shader.data); + } + + if (postproc != nullptr) { + TVMByteArray arr; + arr.data = reinterpret_cast(dmlc::BeginPtr(shader.data)); + arr.size = shader.data.size() * sizeof(uint32_t); + std::string transformed = (*postproc)(arr); + ICHECK_EQ(transformed.length() % 4U, 0U); + shader.data.resize(transformed.size() / 4U); + std::copy(transformed.begin(), transformed.end(), + reinterpret_cast(dmlc::BeginPtr(shader.data))); + } + code_data << spirv_tools.BinaryToText(shader.data); + smap[f_name] = std::move(shader); + } + + return std::make_pair(smap, code_data.str()); +} + +#else + +std::pair, std::string> LowerToSPIRV( + IRModule mod, Target target) { + LOG(FATAL) + << "LowerToSPIRV is called but SPIRV codegen is not enabled. Please set -DUSE_VULKAN=ON."; + return {}; +} + +#endif + +} // namespace codegen +} // namespace tvm diff --git a/src/target/spirv/spirv_utils.h b/src/target/spirv/spirv_utils.h new file mode 100644 index 000000000000..b441a559f813 --- /dev/null +++ b/src/target/spirv/spirv_utils.h @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TARGET_SPIRV_SPIRV_UTILS_H_ +#define TVM_TARGET_SPIRV_SPIRV_UTILS_H_ + +#include +#include + +#include +#include +#include + +#include "../../runtime/spirv/spirv_shader.h" + +namespace tvm { +namespace codegen { +/*! + * \brief Lower an IRModule to SPIRV modules. + * + * \param mod The IRModule to lower. + * \param target The target information. + * \return The map from function names to SPIRV binaries, and the concatenated text representation + * of the SPIRV modules. + */ +std::pair, std::string> LowerToSPIRV( + IRModule mod, Target target); + +} // namespace codegen +} // namespace tvm +#endif // TVM_TARGET_SPIRV_SPIRV_UTILS_H_