From b95db83e319f5ee93f6c21edd6ebf5a9d8526b46 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 1 Nov 2023 15:06:03 -0400 Subject: [PATCH 1/2] [Codegen][Metal] Disable cross-function call in Metal codegen This PR restores the Metal codegen to its state before #15835. Since there will likely be no internal function calls in Metal, we think it is safe to do so. Verified that with this PR, the Metal codegen and iPhone codegen will not fail and will work properly. The iPhone codegen failed because declaring the same function multiple times leads to multiple emissions of the same struct, which the Metal compiler does not accept. --- .github/workflows/main.yml | 1 + src/target/source/codegen_metal.cc | 89 ++++++++++++------- src/target/source/codegen_metal.h | 3 +- .../unittest/test_target_codegen_metal.py | 23 +++++ 4 files changed, 80 insertions(+), 36 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7b4b8d826f36..b45f40f1b83b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -73,6 +73,7 @@ jobs: shell: bash -l {0} run: >- python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum_compile' + python -m pytest -v -s 'tests/python/unittest/test_target_codegen_metal.py::test_func_with_trailing_pod_params' - name: Minimal Metal Compile-and-Run shell: bash -l {0} run: >- diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index ddd7d25f3b5f..86d5956dec19 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -39,8 +39,6 @@ namespace codegen { void CodeGenMetal::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); - // skip the first underscore, so SSA variable starts from _1 - name_supply_->FreshName("v_"); // analyze the data; for (Var arg : f->params) { if (arg.dtype().is_handle()) { @@ -57,15 +55,33 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) { << 
"};\n\n"; } -void CodeGenMetal::PrintFunctionSignature(const String& function_name, const PrimFunc& func, - std::ostream& os) { +void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { + // NOTE: There is no inter-function calls among Metal kernels. + // For now we keep the metal codegen without inter-function call + // process. + // We can switch to follow the flow with inter-function call process + // after the Metal function declaration is properly printed. + // In Metal, for PrimFuncs with signature + // def func(A: Buffer, B: Buffer, x: int, y: float) -> None + // where there are trailing pod parameters, the codegen emits a struct + // struct func_params{ x: int; y: float; } + // for the function. In the flow of inter-function call process, + // the struct will be emitted for every time a function is declared. + // So consequently there are duplicate appearances of a same struct, + // which makes the Metal compiler unable to recognize. + + // clear previous generated state. + this->InitFuncState(func); + // skip the first underscore, so SSA variable starts from _1 + name_supply_->FreshName("v_"); + // add to alloc buffer type. auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. 
- os << "kernel void " << static_cast(global_symbol.value()) << "("; + this->stream << "kernel void " << static_cast(global_symbol.value()) << "("; // Buffer arguments size_t num_buffer = 0; @@ -77,13 +93,13 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) { Var v = func->params[i]; if (!v.dtype().is_handle()) break; - os << " "; + this->stream << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, os); + PrintStorageScope(it->second, this->stream); } - PrintType(GetType(v), os); + PrintType(GetType(v), this->stream); // Register handle data type // TODO(tvm-team): consider simply keep type info in the // type annotation(via a normalizing rewriting). @@ -92,14 +108,15 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri RegisterHandleType(v.get(), prim->dtype); } } - os << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; + this->stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; } // Setup normal arguments. 
size_t nargs = func->params.size() - num_buffer; std::string varg = name_supply_->FreshName("arg"); if (nargs != 0) { std::string arg_buf_type = static_cast(global_symbol.value()) + "_args_t"; - os << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n"; + this->stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer + << ") ]],\n"; // declare the struct decl_stream << "struct " << arg_buf_type << " {\n"; for (size_t i = num_buffer; i < func->params.size(); ++i) { @@ -141,16 +158,22 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri if (work_dim != 0) { // use ushort by default for now - os << " "; - PrintType(DataType::UInt(thread_index_bits_, work_dim), os); - os << " blockIdx [[threadgroup_position_in_grid]],\n"; - os << " "; - PrintType(DataType::UInt(thread_index_bits_, work_dim), os); - os << " threadIdx [[thread_position_in_threadgroup]]\n"; + stream << " "; + PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); + stream << " blockIdx [[threadgroup_position_in_grid]],\n"; + stream << " "; + PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); + stream << " threadIdx [[thread_position_in_threadgroup]]\n"; } thread_work_dim_ = work_dim; - os << ")"; + // the function scope. 
+ stream << ") {\n"; + int func_scope = this->BeginScope(); + this->PrintStmt(func->body); + this->EndScope(func_scope); + this->PrintIndent(); + this->stream << "}\n\n"; } void CodeGenMetal::BindThreadIndex(const IterVar& iv) { @@ -295,6 +318,9 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N } void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) + CHECK(!op->op.as()) + << "CodegenMetal does not support inter-function calls, " + << "but expression " << GetRef(op) << " calls PrimFunc " << op->op; if (op->op.same_as(builtin::reinterpret())) { // generate as_type(ARG) os << "(as_type<"; @@ -337,33 +363,28 @@ runtime::Module BuildMetal(IRModule mod, Target target) { const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile"); std::string fmt = fmetal_compile ? "metallib" : "metal"; - Map functions; - for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; - auto calling_conv = base_func->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - - auto prim_func = Downcast(base_func); - functions.Set(gvar, prim_func); - } + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; + auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()); + std::string func_name = global_symbol.value(); - for (auto [gvar, prim_func] : functions) { - source_maker << "// Function: " << gvar->name_hint << "\n"; + source_maker << "// Function: " << func_name << "\n"; CodeGenMetal cg(target); cg.Init(output_ssa); + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenMetal: expect calling_conv equals 
CallingConv::kDeviceKernelLaunch"; - for (auto [other_gvar, other_prim_func] : functions) { - cg.DeclareFunction(other_gvar, other_prim_func); - } - cg.AddFunction(gvar, prim_func); + cg.AddFunction(kv.first, f); std::string fsource = cg.Finish(); source_maker << fsource << "\n"; if (fmetal_compile) { fsource = (*fmetal_compile)(fsource, target).operator std::string(); } - smap[cg.GetFunctionName(gvar)] = fsource; + smap[func_name] = fsource; } return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 26c991e60df9..9cff3211ce44 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -38,8 +38,7 @@ class CodeGenMetal final : public CodeGenC { explicit CodeGenMetal(Target target); // override print thread tag. void PrintArgUnionDecl(); - void PrintFunctionSignature(const String& function_name, const PrimFunc& func, - std::ostream& os) override; + void AddFunction(const GlobalVar& gvar, const PrimFunc& func) final; void InitFuncState(const PrimFunc& f) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) diff --git a/tests/python/unittest/test_target_codegen_metal.py b/tests/python/unittest/test_target_codegen_metal.py index dcbbba8c9c9f..b4e747a0b4d8 100644 --- a/tests/python/unittest/test_target_codegen_metal.py +++ b/tests/python/unittest/test_target_codegen_metal.py @@ -169,5 +169,28 @@ def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")): np.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5) +@tvm.testing.requires_metal(support_required="compile-only") +def test_func_with_trailing_pod_params(): + from tvm.contrib import xcode # pylint: disable=import-outside-toplevel + + @T.prim_func + def func(A: T.Buffer((16), "float32"), B: T.Buffer((16), "float32"), x: T.float32): + for i 
in T.thread_binding(16, thread="threadIdx.x"): + with T.block("block"): + vi = T.axis.spatial(16, i) + B[vi] = A[vi] + x + + @tvm.register_func("tvm_callback_metal_compile") + def compile_metal(src, target): + return xcode.compile_metal(src) + + mod = tvm.IRModule({"main": func}) + + f = tvm.build(mod, target="metal") + src: str = f.imported_modules[0].get_source() + occurrences = src.count("struct func_kernel_args_t") + assert occurrences == 1, occurrences + + if __name__ == "__main__": tvm.testing.main() From 6105d1e1d72feb57d021e690893a09f5a7905f42 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 1 Nov 2023 23:38:56 -0400 Subject: [PATCH 2/2] Fix the action script --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b45f40f1b83b..fdd2db9c4a81 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -71,7 +71,7 @@ jobs: python -m pytest -v tests/python/all-platform-minimal-test - name: Minimal Metal Compile-Only shell: bash -l {0} - run: >- + run: | python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum_compile' python -m pytest -v -s 'tests/python/unittest/test_target_codegen_metal.py::test_func_with_trailing_pod_params' - name: Minimal Metal Compile-and-Run