Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ jobs:
python -m pytest -v tests/python/all-platform-minimal-test
- name: Minimal Metal Compile-Only
shell: bash -l {0}
run: >-
run: |
python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum_compile'
python -m pytest -v -s 'tests/python/unittest/test_target_codegen_metal.py::test_func_with_trailing_pod_params'
- name: Minimal Metal Compile-and-Run
shell: bash -l {0}
run: >-
Expand Down
89 changes: 55 additions & 34 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ namespace codegen {

void CodeGenMetal::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1
name_supply_->FreshName("v_");
// analyze the data;
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
Expand All @@ -57,15 +55,33 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) {
<< "};\n\n";
}

void CodeGenMetal::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
std::ostream& os) {
void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
// NOTE: There are no inter-function calls among Metal kernels.
// For now we keep the Metal codegen without the inter-function call
// process.
// We can switch to the flow that supports inter-function calls
// once Metal function declarations are printed properly.
// In Metal, for PrimFuncs with the signature
// def func(A: Buffer, B: Buffer, x: int, y: float) -> None
// where there are trailing POD parameters, the codegen emits a struct
// struct func_args_t { x: int; y: float; }
// for the function. In the inter-function call flow, this struct
// would be emitted every time the function is declared, so the same
// struct would appear multiple times in the generated source,
// which the Metal compiler does not accept.

// clear previous generated state.
this->InitFuncState(func);
// skip the first underscore, so SSA variable starts from _1
name_supply_->FreshName("v_");

// add to alloc buffer type.
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";

// Function header.
os << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
this->stream << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";

// Buffer arguments
size_t num_buffer = 0;
Expand All @@ -77,13 +93,13 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) {
Var v = func->params[i];
if (!v.dtype().is_handle()) break;
os << " ";
this->stream << " ";
std::string vid = AllocVarID(v.get());
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os);
PrintStorageScope(it->second, this->stream);
}
PrintType(GetType(v), os);
PrintType(GetType(v), this->stream);
// Register handle data type
// TODO(tvm-team): consider simply keeping the type info in the
// type annotation (via a normalizing rewrite).
Expand All @@ -92,14 +108,15 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
RegisterHandleType(v.get(), prim->dtype);
}
}
os << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
this->stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
}
// Setup normal arguments.
size_t nargs = func->params.size() - num_buffer;
std::string varg = name_supply_->FreshName("arg");
if (nargs != 0) {
std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t";
os << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n";
this->stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer
<< ") ]],\n";
// declare the struct
decl_stream << "struct " << arg_buf_type << " {\n";
for (size_t i = num_buffer; i < func->params.size(); ++i) {
Expand Down Expand Up @@ -141,16 +158,22 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri

if (work_dim != 0) {
// use ushort by default for now
os << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), os);
os << " blockIdx [[threadgroup_position_in_grid]],\n";
os << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), os);
os << " threadIdx [[thread_position_in_threadgroup]]\n";
stream << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
stream << " blockIdx [[threadgroup_position_in_grid]],\n";
stream << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
stream << " threadIdx [[thread_position_in_threadgroup]]\n";
}
thread_work_dim_ = work_dim;

os << ")";
// the function scope.
stream << ") {\n";
int func_scope = this->BeginScope();
this->PrintStmt(func->body);
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
}

void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
Expand Down Expand Up @@ -295,6 +318,9 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N
}

void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
CHECK(!op->op.as<GlobalVarNode>())
<< "CodegenMetal does not support inter-function calls, "
<< "but expression " << GetRef<Call>(op) << " calls PrimFunc " << op->op;
if (op->op.same_as(builtin::reinterpret())) {
// generate as_type<TYPE>(ARG)
os << "(as_type<";
Expand Down Expand Up @@ -337,33 +363,28 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile");
std::string fmt = fmetal_compile ? "metallib" : "metal";

Map<GlobalVar, PrimFunc> functions;
for (auto [gvar, base_func] : mod->functions) {
ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
auto calling_conv = base_func->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";

auto prim_func = Downcast<PrimFunc>(base_func);
functions.Set(gvar, prim_func);
}
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined());
std::string func_name = global_symbol.value();

for (auto [gvar, prim_func] : functions) {
source_maker << "// Function: " << gvar->name_hint << "\n";
source_maker << "// Function: " << func_name << "\n";
CodeGenMetal cg(target);
cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";

for (auto [other_gvar, other_prim_func] : functions) {
cg.DeclareFunction(other_gvar, other_prim_func);
}
cg.AddFunction(gvar, prim_func);
cg.AddFunction(kv.first, f);

std::string fsource = cg.Finish();
source_maker << fsource << "\n";
if (fmetal_compile) {
fsource = (*fmetal_compile)(fsource, target).operator std::string();
}
smap[cg.GetFunctionName(gvar)] = fsource;
smap[func_name] = fsource;
}

return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str());
Expand Down
3 changes: 1 addition & 2 deletions src/target/source/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ class CodeGenMetal final : public CodeGenC {
explicit CodeGenMetal(Target target);
// override print thread tag.
void PrintArgUnionDecl();
void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
std::ostream& os) override;
void AddFunction(const GlobalVar& gvar, const PrimFunc& func) final;
void InitFuncState(const PrimFunc& f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
Expand Down
23 changes: 23 additions & 0 deletions tests/python/unittest/test_target_codegen_metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,28 @@ def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")):
np.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5)


# Compile-only test: builds a PrimFunc whose last parameter is a scalar
# (two buffers followed by a float32) for the Metal target and checks that
# the string "struct func_kernel_args_t" occurs exactly once in the
# generated source.
@tvm.testing.requires_metal(support_required="compile-only")
def test_func_with_trailing_pod_params():
from tvm.contrib import xcode # pylint: disable=import-outside-toplevel

# Kernel under test: B[vi] = A[vi] + x over 16 elements bound to threadIdx.x.
@T.prim_func
def func(A: T.Buffer((16), "float32"), B: T.Buffer((16), "float32"), x: T.float32):
for i in T.thread_binding(16, thread="threadIdx.x"):
with T.block("block"):
vi = T.axis.spatial(16, i)
B[vi] = A[vi] + x

# Register the "tvm_callback_metal_compile" callback so the generated
# source is passed through xcode.compile_metal.
@tvm.register_func("tvm_callback_metal_compile")
def compile_metal(src, target):
return xcode.compile_metal(src)

mod = tvm.IRModule({"main": func})

f = tvm.build(mod, target="metal")
# Fetch the source text of the first imported module and count the struct occurrences.
src: str = f.imported_modules[0].get_source()
occurrences = src.count("struct func_kernel_args_t")
# On failure, the assertion message reports the actual count.
assert occurrences == 1, occurrences


if __name__ == "__main__":
tvm.testing.main()