From 7d33223aff111a33366cc2705333e0856217b37b Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 19 Mar 2023 10:20:49 -0400 Subject: [PATCH] [CODEGEN][METAL] Fix unaligned vector load This PR fixes the implementation of unaligned vector load. Previously vector construction was printed as (float2)(v0, v1). This will cause problem as C have comma expression, and (v0, v1) will be evaluated as v1. The final result will become float2(v1, v1). The bug affects all codegen that uses the default implementation, such as metal. We added a testcase on metal to cover this case. Also updated codegen opencl to keep the old style as that is the convention opencl follows. --- src/target/source/codegen_c.cc | 15 +++++--- src/target/source/codegen_metal.cc | 17 +++------ src/target/source/codegen_metal.h | 4 +- src/target/source/codegen_opencl.cc | 37 +++++++++++++++++++ src/target/source/codegen_opencl.h | 3 ++ .../unittest/test_target_codegen_metal.py | 27 ++++++++++++++ 6 files changed, 83 insertions(+), 20 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 514ec8395821..a4332476f335 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -802,15 +802,16 @@ void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) } void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) - // constraint of current logic - ICHECK_EQ(op->base.dtype(), DataType::Int(32)); - os << "((int" << op->lanes << ")("; + // NOTE: C have comma expression so cannot use (int2)(v0, v1) + // instead should use int2(v0, v1) + PrintType(op->dtype, os); + os << "("; for (int i = 0; i < op->lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; if (i != op->lanes - 1) os << ", "; } - os << "))"; + os << ")"; } void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { @@ -999,9 +1000,11 @@ void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, } if (i == 0) { - os << "(("; + // NOTE: C have comma expression so cannot use (float2)(v0, v1) + // instead should use float2(v0, v1) + os << "("; PrintType(t, os); - os << ")("; + os << "("; } os << value; if (i != t.lanes() - 1) { diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index ad9560eef214..534e2c3654c4 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -299,17 +299,6 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N os << ')'; } -void CodeGenMetal::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) - PrintType(op->dtype, os); - os << "("; - for (int i = 0; i < op->lanes; ++i) { - if (i != 0) os << ", "; - os << "(" << PrintExpr(op->base) << ")" - << "+(" << PrintExpr(op->stride) << "*" << i << ")"; - } - os << ')'; -} - void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->op.same_as(builtin::reinterpret())) { // generate as_type(ARG) @@ -369,7 +358,11 @@ runtime::Module BuildMetal(IRModule mod, Target target) { code << fsource; } - return MetalModuleCreate(code.str(), fmt, ExtractFuncInfo(mod), source.str()); + std::string code_str = code.str(); + if (const auto* f = Registry::Get("tvm_callback_metal_postproc")) { + code_str = (*f)(code_str).operator std::string(); + } + return MetalModuleCreate(code_str, fmt, ExtractFuncInfo(mod), source.str()); } TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 4e464c6636a8..99332e004678 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -51,9 +51,9 @@ class CodeGenMetal final : public CodeGenC { void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; // overload visitor void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) + // reuse parent's function. using CodeGenC::PrintType; diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 6e5a9db4d37c..89cc09aeadb2 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -303,6 +303,31 @@ void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr stream << ");\n"; } +void CodeGenOpenCL::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, + std::ostream& os) { // NOLINT(*) + ICHECK_GT(t.lanes(), 1); + if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + if (i != 0) { + os << "|"; + } + os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; + return; + } + if (i == 0) { + // NOTE: opencl print things as (float2)(v0, v1) + os << "(("; + PrintType(t, os); + os << ")("; + } + os << value; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << "))"; + } + return; +} + void CodeGenOpenCL::PrintStorageSync(const CallNode* op) { const std::string& sync = op->args[0].as()->value; if (sync == "warp") { @@ -490,6 +515,18 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // os << "))"; } +void CodeGenOpenCL::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) + os << "(("; + PrintType(op->dtype, os); + os << ")("; + for (int i = 0; i < op->lanes; i++) { + os << "(" << PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + if (i != op->lanes - 1) os << ", "; + } + os << "))"; +} + void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) if (std::isinf(op->value)) { if (op->value < 0) { diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index bf3046f0d8df..169759976119 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -51,6 +51,8 @@ class CodeGenOpenCL final : public CodeGenC { std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) final; void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) + void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, + std::ostream& os) final; // NOLINT(*) // the address of load/store void PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base, std::ostream& os); // NOLINT(*) @@ -62,6 +64,7 @@ class CodeGenOpenCL final : public CodeGenC { // overload visitor void VisitStmt_(const AllocateNode* op) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/tests/python/unittest/test_target_codegen_metal.py b/tests/python/unittest/test_target_codegen_metal.py index 002cf3c69640..eee54d052342 100644 --- a/tests/python/unittest/test_target_codegen_metal.py +++ b/tests/python/unittest/test_target_codegen_metal.py @@ -22,6 +22,8 @@ from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 from tvm.contrib import nvcc import tvm.testing +import tvm.script +from tvm.script import tir as T tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -54,6 +56,31 @@ def check_inf_nan(dev, n, value, dtype): check_inf_nan(dev, 1, float("nan"), "float16") +@tvm.testing.requires_gpu +@tvm.testing.requires_metal +def test_unaligned_vectorize(): + @tvm.script.ir_module + class IRModule: + @T.prim_func + def main(A: T.Buffer((2, 3), "float32"), B: T.Buffer((6,), "float32")): + T.func_attr({"global_symbol": "main"}) + for i0_1 in T.thread_binding(3, thread="threadIdx.x"): + for i0_0 in T.vectorized(2): + with T.block("block"): + vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1) + B[vi0] = A[vi0 // 3, vi0 % 3] + + target = "metal" + dev = tvm.metal() + + a = (np.arange(6).reshape(2, 3)).astype("float32") + a_nd = tvm.nd.array(a, dev) + b_nd = tvm.nd.empty((6,), "float32", dev) + f = tvm.build(IRModule, target=target) + f(a_nd, b_nd) + np.testing.assert_allclose(b_nd.numpy(), a.reshape(6), atol=1e-5, rtol=1e-5) + + @tvm.testing.requires_gpu @tvm.testing.requires_metal def test_metal_erf():