diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 514ec8395821..a4332476f335 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -802,15 +802,16 @@ void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) } void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) - // constraint of current logic - ICHECK_EQ(op->base.dtype(), DataType::Int(32)); - os << "((int" << op->lanes << ")("; + // NOTE: C have comma expression so cannot use (int2)(v0, v1) + // instead should use int2(v0, v1) + PrintType(op->dtype, os); + os << "("; for (int i = 0; i < op->lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; if (i != op->lanes - 1) os << ", "; } - os << "))"; + os << ")"; } void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { @@ -999,9 +1000,11 @@ void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, } if (i == 0) { - os << "(("; + // NOTE: C have comma expression so cannot use (float2)(v0, v1) + // instead should use float2(v0, v1) + os << "("; PrintType(t, os); - os << ")("; + os << "("; } os << value; if (i != t.lanes() - 1) { diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index ad9560eef214..534e2c3654c4 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -299,17 +299,6 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N os << ')'; } -void CodeGenMetal::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) - PrintType(op->dtype, os); - os << "("; - for (int i = 0; i < op->lanes; ++i) { - if (i != 0) os << ", "; - os << "(" << PrintExpr(op->base) << ")" - << "+(" << PrintExpr(op->stride) << "*" << i << ")"; - } - os << ')'; -} - void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->op.same_as(builtin::reinterpret())) { // generate as_type(ARG) @@ -369,7 +358,11 @@ runtime::Module BuildMetal(IRModule mod, Target target) { code << fsource; } - return MetalModuleCreate(code.str(), fmt, ExtractFuncInfo(mod), source.str()); + std::string code_str = code.str(); + if (const auto* f = Registry::Get("tvm_callback_metal_postproc")) { + code_str = (*f)(code_str).operator std::string(); + } + return MetalModuleCreate(code_str, fmt, ExtractFuncInfo(mod), source.str()); } TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 4e464c6636a8..99332e004678 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -51,9 +51,9 @@ class CodeGenMetal final : public CodeGenC { void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; // overload visitor void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) + // reuse parent's function. using CodeGenC::PrintType; diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 6e5a9db4d37c..89cc09aeadb2 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -303,6 +303,31 @@ void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr stream << ");\n"; } +void CodeGenOpenCL::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, + std::ostream& os) { // NOLINT(*) + ICHECK_GT(t.lanes(), 1); + if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + if (i != 0) { + os << "|"; + } + os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; + return; + } + if (i == 0) { + // NOTE: opencl print things as (float2)(v0, v1) + os << "(("; + PrintType(t, os); + os << ")("; + } + os << value; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << "))"; + } + return; +} + void CodeGenOpenCL::PrintStorageSync(const CallNode* op) { const std::string& sync = op->args[0].as()->value; if (sync == "warp") { @@ -490,6 +515,18 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // os << "))"; } +void CodeGenOpenCL::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) + os << "(("; + PrintType(op->dtype, os); + os << ")("; + for (int i = 0; i < op->lanes; i++) { + os << "(" << PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + if (i != op->lanes - 1) os << ", "; + } + os << "))"; +} + void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) if (std::isinf(op->value)) { if (op->value < 0) { diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index bf3046f0d8df..169759976119 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -51,6 +51,8 @@ class CodeGenOpenCL final : public CodeGenC { std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) final; void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) + void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, + std::ostream& os) final; // NOLINT(*) // the address of load/store void PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base, std::ostream& os); // NOLINT(*) @@ -62,6 +64,7 @@ class CodeGenOpenCL final : public CodeGenC { // overload visitor void VisitStmt_(const AllocateNode* op) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/tests/python/unittest/test_target_codegen_metal.py b/tests/python/unittest/test_target_codegen_metal.py index 002cf3c69640..eee54d052342 100644 --- a/tests/python/unittest/test_target_codegen_metal.py +++ b/tests/python/unittest/test_target_codegen_metal.py @@ -22,6 +22,8 @@ from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 from tvm.contrib import nvcc import tvm.testing +import tvm.script +from tvm.script import tir as T tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -54,6 +56,31 @@ def check_inf_nan(dev, n, value, dtype): check_inf_nan(dev, 1, float("nan"), "float16") +@tvm.testing.requires_gpu +@tvm.testing.requires_metal +def test_unaligned_vectorize(): + @tvm.script.ir_module + class IRModule: + @T.prim_func + def main(A: T.Buffer((2, 3), "float32"), B: T.Buffer((6,), "float32")): + T.func_attr({"global_symbol": "main"}) + for i0_1 in T.thread_binding(3, thread="threadIdx.x"): + for i0_0 in T.vectorized(2): + with T.block("block"): + vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1) + B[vi0] = A[vi0 // 3, vi0 % 3] + + target = "metal" + dev = tvm.metal() + + a = (np.arange(6).reshape(2, 3)).astype("float32") + a_nd = tvm.nd.array(a, dev) + b_nd = tvm.nd.empty((6,), "float32", dev) + f = tvm.build(IRModule, target=target) + f(a_nd, b_nd) + np.testing.assert_allclose(b_nd.numpy(), a.reshape(6), atol=1e-5, rtol=1e-5) + + @tvm.testing.requires_gpu @tvm.testing.requires_metal def test_metal_erf():