diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h index 77cdf64df8bc..d01523b1faba 100644 --- a/src/runtime/metal/metal_module.h +++ b/src/runtime/metal/metal_module.h @@ -41,13 +41,14 @@ static constexpr const int kMetalMaxNumDevice = 32; /*! * \brief create a metal module from data. * - * \param data The data content. - * \param fmt The format of the data, can be "metal" or "metallib" + * \param smap The map from name to each shader kernel. * \param fmap The map function information map of each function. - * \param source Optional, source file + * \param fmt The format of the source, can be "metal" or "metallib" + * \param source Optional, source file, concatenaed for debug dump */ -Module MetalModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source); +Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_METAL_METAL_MODULE_H_ diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index a5eddf3a9556..aef6cf5ebe36 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -30,23 +30,26 @@ #include "../file_utils.h" #include "../meta_data.h" #include "../pack_args.h" -#include "../source_utils.h" #include "../thread_storage_scope.h" #include "metal_common.h" namespace tvm { namespace runtime { +// The version of metal module +// for future compatibility checking +// bump when we change the binary format. +static constexpr const char* kMetalModuleVersion = "0.1.0"; + // Module to support thread-safe multi-GPU execution. // The runtime will contain a per-device module table // The modules will be lazily loaded class MetalModuleNode final : public runtime::ModuleNode { public: - explicit MetalModuleNode(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) - : data_(data), fmt_(fmt), fmap_(fmap), source_(source) { - parsed_kernels_ = SplitKernels(data); - } + explicit MetalModuleNode(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source) + : smap_(smap), fmap_(fmap), fmt_(fmt), source_(source) {} const char* type_key() const final { return "metal"; } /*! \brief Get the property of the runtime module. */ @@ -57,27 +60,19 @@ int GetPropertyMask() const final { PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = GetFileFormat(file_name, format); - ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; - std::string meta_file = GetMetaFilePath(file_name); - SaveMetaDataToFile(meta_file, fmap_); - SaveBinaryToFile(file_name, data_); + LOG(FATAL) << "Do not support save to file, use save to binary and export instead"; } void SaveToBinary(dmlc::Stream* stream) final { - stream->Write(fmt_); + std::string version = kMetalModuleVersion; + stream->Write(version); + stream->Write(smap_); stream->Write(fmap_); - stream->Write(data_); + stream->Write(fmt_); } std::string GetSource(const std::string& format) final { - if (format == fmt_) return data_; - if (source_.length() != 0) { - return source_; - } else if (fmt_ == "metal") { - return data_; - } else { - return ""; - } + // return text source if available. + return source_; } // get a from primary context in device_id @@ -95,15 +90,11 @@ void SaveToBinary(dmlc::Stream* stream) final { // compile NSError* err_msg = nil; id lib = nil; - std::string source; - auto kernel = parsed_kernels_.find(func_name); - // If we cannot find this kernel in parsed_kernels_, it means that all kernels going together - // without explicit separator. In this case we use data_ with all kernels. It done for backward - // compatibility. - if (kernel != parsed_kernels_.end()) - source = kernel->second; - else - source = data_; + auto kernel = smap_.find(func_name); + // Directly lookup kernels + ICHECK(kernel != smap_.end()); + const std::string& source = kernel->second; + if (fmt_ == "metal") { MTLCompileOptions* opts = [MTLCompileOptions alloc]; opts.languageVersion = MTLLanguageVersion2_3; @@ -115,7 +106,8 @@ void SaveToBinary(dmlc::Stream* stream) final { error:&err_msg]; [opts dealloc]; if (lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; + LOG(FATAL) << "Fail to compile metal source:" + << [[err_msg localizedDescription] UTF8String]; } if (err_msg != nil) { LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String]; @@ -161,20 +153,18 @@ void SaveToBinary(dmlc::Stream* stream) final { } } }; - // the binary data - std::string data_; - // The format - std::string fmt_; + // the source shader data, can be mtl or binary + std::unordered_map smap_; // function information table. std::unordered_map fmap_; + // The format + std::string fmt_; // The source std::string source_; // function information. std::vector finfo_; // internal mutex when updating the module std::mutex mutex_; - // parsed kernel data - std::unordered_map parsed_kernels_; }; // a wrapped function class to get packed func. @@ -272,39 +262,45 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons return pf; } -Module MetalModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) { +Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source) { ObjectPtr n; AUTORELEASEPOOL { metal::MetalWorkspace::Global()->Init(); - n = make_object(data, fmt, fmap, source); + n = make_object(smap, fmap, fmt, source); }; return Module(n); } -// Load module from module. -Module MetalModuleLoadFile(const std::string& file_name, const std::string& format) { - std::string data; - std::unordered_map fmap; - std::string fmt = GetFileFormat(file_name, format); - std::string meta_file = GetMetaFilePath(file_name); - LoadBinaryFromFile(file_name, &data); - LoadMetaDataFromFile(meta_file, &fmap); - return MetalModuleCreate(data, fmt, fmap, ""); -} +TVM_REGISTER_GLOBAL("runtime.module.create_metal_module") + .set_body_typed([](Map smap, std::string fmap_json, std::string fmt, + std::string source) { + std::istringstream stream(fmap_json); + std::unordered_map fmap; + dmlc::JSONReader reader(&stream); + reader.Read(&fmap); + return MetalModuleCreate( + std::unordered_map(smap.begin(), smap.end()), fmap, fmt, + source); + }); Module MetalModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); - std::string data; + // version is reserved for future changes and + // is discarded for now + std::string ver; + std::unordered_map smap; std::unordered_map fmap; std::string fmt; - stream->Read(&fmt); + + stream->Read(&ver); + stream->Read(&smap); stream->Read(&fmap); - stream->Read(&data); - return MetalModuleCreate(data, fmt, fmap, ""); -} + stream->Read(&fmt); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal").set_body_typed(MetalModuleLoadFile); + return MetalModuleCreate(smap, fmap, fmt, ""); +} TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary); } // namespace runtime diff --git a/src/target/opt/build_metal_off.cc b/src/target/opt/build_metal_off.cc index 3cfe1316e7ce..555aa5002f98 100644 --- a/src/target/opt/build_metal_off.cc +++ b/src/target/opt/build_metal_off.cc @@ -26,10 +26,11 @@ namespace tvm { namespace runtime { -Module MetalModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) { +Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source) { LOG(WARNING) << "Metal runtime not enabled, return a source module..."; - return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "metal"); + return codegen::DeviceSourceModuleCreate(source, fmt, fmap, "metal"); } } // namespace runtime diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 767311cb5aed..44da240dd5b0 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include "../../runtime/metal/metal_module.h" @@ -336,33 +337,35 @@ runtime::Module BuildMetal(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; - std::stringstream code; - std::stringstream source; - std::string fmt = "metal"; + std::ostringstream source_maker; + std::unordered_map smap; + const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile"); + std::string fmt = fmetal_compile ? "metallib" : "metal"; + for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; - code << "// Function: " << kv.first->name_hint << std::endl; + auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()); + std::string func_name = global_symbol.value(); + + source_maker << "// Function: " << func_name << "\n"; CodeGenMetal cg(target); cg.Init(output_ssa); auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + cg.AddFunction(f); std::string fsource = cg.Finish(); - if (const auto* f = Registry::Get("tvm_callback_metal_compile")) { - source << fsource; - fsource = (*f)(fsource).operator std::string(); - fmt = "metallib"; + source_maker << fsource << "\n"; + if (fmetal_compile) { + fsource = (*fmetal_compile)(fsource).operator std::string(); } - code << fsource; + smap[func_name] = fsource; } - std::string code_str = code.str(); - if (const auto* f = Registry::Get("tvm_callback_metal_postproc")) { - code_str = (*f)(code_str).operator std::string(); - } - return MetalModuleCreate(code_str, fmt, ExtractFuncInfo(mod), source.str()); + return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); } TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);