From 0a9859bc64135a7011d5d0848ccab58d1e03ef93 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 26 Apr 2023 11:14:48 -0400 Subject: [PATCH 1/2] [METAL] Update metal runtime to directly store kernel map This PR updates the metal runtime storage format to directly store the kernel map. This enables more robust support for the metallib binary format, which may not be compatible with the previous approach of splitting a concatenated string. It changes the binary format of the metal module. We also add a version string to make future format updates easier. --- src/runtime/metal/metal_module.h | 11 +-- src/runtime/metal/metal_module.mm | 108 ++++++++++++++--------------- src/target/source/codegen_metal.cc | 30 ++++---- 3 files changed, 74 insertions(+), 75 deletions(-) diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h index 77cdf64df8bc..d01523b1faba 100644 --- a/src/runtime/metal/metal_module.h +++ b/src/runtime/metal/metal_module.h @@ -41,13 +41,14 @@ static constexpr const int kMetalMaxNumDevice = 32; /*! * \brief create a metal module from data. * - * \param data The data content. - * \param fmt The format of the data, can be "metal" or "metallib" + * \param smap The map from name to each shader kernel. * \param fmap The map function information map of each function. 
- * \param source Optional, source file + * \param fmt The format of the source, can be "metal" or "metallib" + * \param source Optional, source file, concatenated for debug dump */ -Module MetalModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source); +Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_METAL_METAL_MODULE_H_ diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index a5eddf3a9556..aef6cf5ebe36 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -30,23 +30,26 @@ #include "../file_utils.h" #include "../meta_data.h" #include "../pack_args.h" -#include "../source_utils.h" #include "../thread_storage_scope.h" #include "metal_common.h" namespace tvm { namespace runtime { +// The version of metal module +// for future compatibility checking +// bump when we change the binary format. +static constexpr const char* kMetalModuleVersion = "0.1.0"; + // Module to support thread-safe multi-GPU execution. // The runtime will contain a per-device module table // The modules will be lazily loaded class MetalModuleNode final : public runtime::ModuleNode { public: - explicit MetalModuleNode(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) - : data_(data), fmt_(fmt), fmap_(fmap), source_(source) { - parsed_kernels_ = SplitKernels(data); - } + explicit MetalModuleNode(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source) + : smap_(smap), fmap_(fmap), fmt_(fmt), source_(source) {} const char* type_key() const final { return "metal"; } /*! \brief Get the property of the runtime module. 
*/ @@ -57,27 +60,19 @@ int GetPropertyMask() const final { PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = GetFileFormat(file_name, format); - ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; - std::string meta_file = GetMetaFilePath(file_name); - SaveMetaDataToFile(meta_file, fmap_); - SaveBinaryToFile(file_name, data_); + LOG(FATAL) << "Do not support save to file, use save to binary and export instead"; } void SaveToBinary(dmlc::Stream* stream) final { - stream->Write(fmt_); + std::string version = kMetalModuleVersion; + stream->Write(version); + stream->Write(smap_); stream->Write(fmap_); - stream->Write(data_); + stream->Write(fmt_); } std::string GetSource(const std::string& format) final { - if (format == fmt_) return data_; - if (source_.length() != 0) { - return source_; - } else if (fmt_ == "metal") { - return data_; - } else { - return ""; - } + // return text source if available. + return source_; } // get a from primary context in device_id @@ -95,15 +90,11 @@ void SaveToBinary(dmlc::Stream* stream) final { // compile NSError* err_msg = nil; id lib = nil; - std::string source; - auto kernel = parsed_kernels_.find(func_name); - // If we cannot find this kernel in parsed_kernels_, it means that all kernels going together - // without explicit separator. In this case we use data_ with all kernels. It done for backward - // compatibility. 
- if (kernel != parsed_kernels_.end()) - source = kernel->second; - else - source = data_; + auto kernel = smap_.find(func_name); + // Directly lookup kernels + ICHECK(kernel != smap_.end()); + const std::string& source = kernel->second; + if (fmt_ == "metal") { MTLCompileOptions* opts = [MTLCompileOptions alloc]; opts.languageVersion = MTLLanguageVersion2_3; @@ -115,7 +106,8 @@ void SaveToBinary(dmlc::Stream* stream) final { error:&err_msg]; [opts dealloc]; if (lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; + LOG(FATAL) << "Fail to compile metal source:" + << [[err_msg localizedDescription] UTF8String]; } if (err_msg != nil) { LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String]; @@ -161,20 +153,18 @@ void SaveToBinary(dmlc::Stream* stream) final { } } }; - // the binary data - std::string data_; - // The format - std::string fmt_; + // the source shader data, can be mtl or binary + std::unordered_map smap_; // function information table. std::unordered_map fmap_; + // The format + std::string fmt_; // The source std::string source_; // function information. std::vector finfo_; // internal mutex when updating the module std::mutex mutex_; - // parsed kernel data - std::unordered_map parsed_kernels_; }; // a wrapped function class to get packed func. @@ -272,39 +262,45 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons return pf; } -Module MetalModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) { +Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source) { ObjectPtr n; AUTORELEASEPOOL { metal::MetalWorkspace::Global()->Init(); - n = make_object(data, fmt, fmap, source); + n = make_object(smap, fmap, fmt, source); }; return Module(n); } -// Load module from module. 
-Module MetalModuleLoadFile(const std::string& file_name, const std::string& format) { - std::string data; - std::unordered_map fmap; - std::string fmt = GetFileFormat(file_name, format); - std::string meta_file = GetMetaFilePath(file_name); - LoadBinaryFromFile(file_name, &data); - LoadMetaDataFromFile(meta_file, &fmap); - return MetalModuleCreate(data, fmt, fmap, ""); -} +TVM_REGISTER_GLOBAL("runtime.module.create_metal_module") + .set_body_typed([](Map smap, std::string fmap_json, std::string fmt, + std::string source) { + std::istringstream stream(fmap_json); + std::unordered_map fmap; + dmlc::JSONReader reader(&stream); + reader.Read(&fmap); + return MetalModuleCreate( + std::unordered_map(smap.begin(), smap.end()), fmap, fmt, + source); + }); Module MetalModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); - std::string data; + // version is reserved for future changes and + // is discarded for now + std::string ver; + std::unordered_map smap; std::unordered_map fmap; std::string fmt; - stream->Read(&fmt); + + stream->Read(&ver); + stream->Read(&smap); stream->Read(&fmap); - stream->Read(&data); - return MetalModuleCreate(data, fmt, fmap, ""); -} + stream->Read(&fmt); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal").set_body_typed(MetalModuleLoadFile); + return MetalModuleCreate(smap, fmap, fmt, ""); +} TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary); } // namespace runtime diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 767311cb5aed..9e833e9ace24 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -336,33 +336,35 @@ runtime::Module BuildMetal(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; - std::stringstream code; - std::stringstream source; - std::string fmt = "metal"; + std::ostringstream source_maker; + std::unordered_map smap; + const auto* fmetal_compile = 
Registry::Get("tvm_callback_metal_compile"); + std::string fmt = fmetal_compile ? "metallib" : "metal"; + for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; - code << "// Function: " << kv.first->name_hint << std::endl; + auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()); + std::string func_name = global_symbol.value(); + + source_maker << "// Function: " << func_name << "\n"; CodeGenMetal cg(target); cg.Init(output_ssa); auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + cg.AddFunction(f); std::string fsource = cg.Finish(); - if (const auto* f = Registry::Get("tvm_callback_metal_compile")) { - source << fsource; - fsource = (*f)(fsource).operator std::string(); - fmt = "metallib"; + source_maker << fsource << "\n"; + if (fmetal_compile) { + fsource = (*fmetal_compile)(fsource).operator std::string(); } - code << fsource; + smap[func_name] = fsource; } - std::string code_str = code.str(); - if (const auto* f = Registry::Get("tvm_callback_metal_postproc")) { - code_str = (*f)(code_str).operator std::string(); - } - return MetalModuleCreate(code_str, fmt, ExtractFuncInfo(mod), source.str()); + return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); } TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal); From 3219e35fdf678123f0b8e8849521cece652d5035 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 26 Apr 2023 14:12:43 -0400 Subject: [PATCH 2/2] Fix lint --- src/target/opt/build_metal_off.cc | 7 ++++--- src/target/source/codegen_metal.cc | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/target/opt/build_metal_off.cc b/src/target/opt/build_metal_off.cc index 3cfe1316e7ce..555aa5002f98 100644 --- a/src/target/opt/build_metal_off.cc 
+++ b/src/target/opt/build_metal_off.cc @@ -26,10 +26,11 @@ namespace tvm { namespace runtime { -Module MetalModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) { +Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source) { LOG(WARNING) << "Metal runtime not enabled, return a source module..."; - return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "metal"); + return codegen::DeviceSourceModuleCreate(source, fmt, fmap, "metal"); } } // namespace runtime diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 9e833e9ace24..44da240dd5b0 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include "../../runtime/metal/metal_module.h"