From a7f62ebcea7b8459220d21c327b1a32c0bafa55f Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Thu, 11 Mar 2021 13:13:20 +0000 Subject: [PATCH 01/11] Add support for AVX-512 VNNI saturating dot products This commit adds support to Intel VNNI saturating dot product instructions vpdpbuds and vpdpwssd This was accomplished by adding a new VectorReduce operation to perform the saturating_add and exposing a new inline reduction saturaring_sum. Users can then write RDom r(0, 4); f(x) = saturating_sum(i32(0), i16(i8(g(x + r)) * u8(h(x + r)))) bool override_associativity_test = true; int vector_width = 4; Var xo, xi; f.update() .split(x, xo, xi, vector_width) .atomic(override_associativity_test) .vectorize(r) .vectorize(xi); To lower the expression into a call to vpdpbuds. Note that override_associativity_test is set to true or halide will fail to prove the associativity of the saturating_add operation Add support for VectorReduce::SaturatingAdd in CodeGen_LLVM Code is correctly generated when no intrinsic is available to perform a saturating dot product. Add vpdpbusds,vpdpwssd tests to simd_op_check Test if the saturating dot product instructions are being generated for AVX512_SapphireRapids targets --- src/CodeGen_LLVM.cpp | 67 +++++++++++++++++------------- src/CodeGen_LLVM.h | 2 +- src/CodeGen_X86.cpp | 37 +++++++++++------ src/IR.h | 1 + src/IRPrinter.cpp | 3 ++ src/InlineReductions.cpp | 16 +++++++ src/InlineReductions.h | 2 + src/VectorizeLoops.cpp | 8 ++++ src/runtime/x86_avx512.ll | 48 +++++++++++++++++++++ test/correctness/simd_op_check.cpp | 17 ++++++++ test/correctness/simd_op_check.h | 4 +- 11 files changed, 162 insertions(+), 43 deletions(-) diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 133f5d20f779..92337351bbbb 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1437,12 +1437,11 @@ void CodeGen_LLVM::visit(const Variable *op) { } template -bool CodeGen_LLVM::try_to_fold_vector_reduce(const Op *op) { - const VectorReduce *red = op->a.template as(); - Expr b = op->b; +bool CodeGen_LLVM::try_to_fold_vector_reduce(Expr a, Expr b) { + const VectorReduce *red = a.as(); if (!red) { - red = op->b.template as(); - b = op->a; + red = b.as(); + b = a; } if (red && ((std::is_same::value && red->op == VectorReduce::Add) || @@ -1450,7 +1449,8 @@ bool CodeGen_LLVM::try_to_fold_vector_reduce(const Op *op) { (std::is_same::value && red->op == VectorReduce::Max) || (std::is_same::value && red->op == VectorReduce::Mul) || (std::is_same::value && red->op == VectorReduce::And) || - (std::is_same::value && red->op == VectorReduce::Or))) { + (std::is_same::value && red->op == VectorReduce::Or) || + (std::is_same::value && red->op == VectorReduce::SaturatingAdd))) { codegen_vector_reduce(red, b); return true; } @@ -1465,7 +1465,7 @@ void CodeGen_LLVM::visit(const Add *op) { } // Some backends can fold the add into a vector reduce - if (try_to_fold_vector_reduce(op)) { + if (try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -1509,7 +1509,7 @@ void CodeGen_LLVM::visit(const Mul *op) { return; } - if (try_to_fold_vector_reduce(op)) { + if (try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -1569,7 +1569,7 @@ void CodeGen_LLVM::visit(const Min *op) { return; } - if (try_to_fold_vector_reduce(op)) { + if (try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -1589,7 +1589,7 @@ void CodeGen_LLVM::visit(const Max *op) { return; } - if (try_to_fold_vector_reduce(op)) { + if (try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -1708,7 +1708,7 @@ void CodeGen_LLVM::visit(const GE *op) { } void CodeGen_LLVM::visit(const And *op) { - if (try_to_fold_vector_reduce(op)) { + if (try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -1718,7 +1718,7 @@ void CodeGen_LLVM::visit(const And *op) { } void CodeGen_LLVM::visit(const Or *op) { - if (try_to_fold_vector_reduce(op)) { + if (try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -2806,24 +2806,30 @@ void CodeGen_LLVM::visit(const Call *op) { } } else if (op->is_intrinsic(Call::saturating_add) || op->is_intrinsic(Call::saturating_sub)) { internal_assert(op->args.size() == 2); - std::string intrin; - if (op->type.is_int()) { - intrin = "llvm.s"; - } else { - internal_assert(op->type.is_uint()); - intrin = "llvm.u"; - } - if (op->is_intrinsic(Call::saturating_add)) { - intrin += "add.sat."; - } else { - internal_assert(op->is_intrinsic(Call::saturating_sub)); - intrin += "sub.sat."; - } - if (op->type.lanes() > 1) { - intrin += "v" + std::to_string(op->type.lanes()); + + // Try to fold the vector reduce for a call to saturating_add + const bool folded = try_to_fold_vector_reduce(op->args[0], op->args[1]); + + if (!folded) { + std::string intrin; + if (op->type.is_int()) { + intrin = "llvm.s"; + } else { + internal_assert(op->type.is_uint()); + intrin = "llvm.u"; + } + if (op->is_intrinsic(Call::saturating_add)) { + intrin += "add.sat."; + } else { + internal_assert(op->is_intrinsic(Call::saturating_sub)); + intrin += "sub.sat."; + } + if (op->type.lanes() > 1) { + intrin += "v" + std::to_string(op->type.lanes()); + } + intrin += "i" + std::to_string(op->type.bits()); + value = call_intrin(op->type, op->type.lanes(), intrin, op->args); } - intrin += "i" + std::to_string(op->type.bits()); - value = call_intrin(op->type, op->type.lanes(), intrin, op->args); } else if (op->is_intrinsic(Call::stringify)) { internal_assert(!op->args.empty()); @@ -4371,6 +4377,9 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini case VectorReduce::Or: binop = Or::make; break; + case VectorReduce::SaturatingAdd: + binop = saturating_add; + break; } if (op->type.is_bool() && op->op == VectorReduce::Or) { diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index 955987c8bd06..f2722230aba4 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -567,7 +567,7 @@ class CodeGen_LLVM : public IRVisitor { /** A helper routine for generating folded vector reductions. */ template - bool try_to_fold_vector_reduce(const Op *op); + bool try_to_fold_vector_reduce(Expr a, Expr b); }; } // namespace Internal diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 2038dcce75c8..cc17eb8c21d6 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -178,12 +178,22 @@ const x86Intrinsic intrinsic_defs[] = { {"dpbf16psx16", Float(32, 16), "dot_product", {Float(32, 16), BFloat(16, 32), BFloat(16, 32)}, Target::AVX512_SapphireRapids}, {"dpbf16psx8", Float(32, 8), "dot_product", {Float(32, 8), BFloat(16, 16), BFloat(16, 16)}, Target::AVX512_SapphireRapids}, {"dpbf16psx4", Float(32, 4), "dot_product", {Float(32, 4), BFloat(16, 8), BFloat(16, 8)}, Target::AVX512_SapphireRapids}, + {"dpbusdx16", Int(32, 16), "dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_SapphireRapids}, {"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids}, {"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids}, + {"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids}, {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids}, {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, + + {"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_SapphireRapids}, + {"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids}, + {"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids}, + + {"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids}, + {"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids}, + {"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, }; // clang-format on @@ -505,13 +515,14 @@ void CodeGen_X86::visit(const Call *op) { } void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { - if (op->op != VectorReduce::Add) { + if (op->op != VectorReduce::Add && op->op != VectorReduce::SaturatingAdd) { CodeGen_Posix::codegen_vector_reduce(op, init); return; } const int factor = op->value.type().lanes() / op->type.lanes(); struct Pattern { + VectorReduce::Operator reduce_op; int factor; Expr pattern; const char *intrin; @@ -524,24 +535,26 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init }; // clang-format off static const Pattern patterns[] = { - {2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit}, - {2, i32(widening_mul(wild_i16x_, wild_i16x_)), "dot_product", {}, Pattern::CombineInit}, - {4, i32(widening_mul(wild_u8x_, wild_i8x_)), "dot_product", {}, Pattern::CombineInit}, - {4, i32(widening_mul(wild_i8x_, wild_u8x_)), "dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands}, - {2, i32(widening_mul(wild_i16x_, wild_i16x_)), "pmaddwd", Int(16)}, - {2, i32(widening_mul(wild_i8x_, wild_i8x_)), "pmaddwd", Int(16)}, - {2, i32(widening_mul(wild_i8x_, wild_u8x_)), "pmaddwd", Int(16)}, - {2, i32(widening_mul(wild_u8x_, wild_i8x_)), "pmaddwd", Int(16)}, - {2, i32(widening_mul(wild_u8x_, wild_u8x_)), "pmaddwd", Int(16)}, + {VectorReduce::Add, 2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit}, + {VectorReduce::Add, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "dot_product", {}, Pattern::CombineInit}, + {VectorReduce::Add, 4, i32(widening_mul(wild_u8x_, wild_i8x_)), "dot_product", {}, Pattern::CombineInit}, + {VectorReduce::Add, 4, i32(widening_mul(wild_i8x_, wild_u8x_)), "dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands}, + {VectorReduce::SaturatingAdd, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "saturating_dot_product", {}, Pattern::CombineInit}, + {VectorReduce::SaturatingAdd, 4, i32(widening_mul(wild_u8x_, wild_i8x_)), "saturating_dot_product", {}, Pattern::CombineInit}, + {VectorReduce::SaturatingAdd, 4, i32(widening_mul(wild_i8x_, wild_u8x_)), "saturating_dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands}, + {VectorReduce::Add, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "pmaddwd", Int(16)}, + {VectorReduce::Add, 2, i32(widening_mul(wild_i8x_, wild_i8x_)), "pmaddwd", Int(16)}, + {VectorReduce::Add, 2, i32(widening_mul(wild_i8x_, wild_u8x_)), "pmaddwd", Int(16)}, + {VectorReduce::Add, 2, i32(widening_mul(wild_u8x_, wild_i8x_)), "pmaddwd", Int(16)}, + {VectorReduce::Add, 2, i32(widening_mul(wild_u8x_, wild_u8x_)), "pmaddwd", Int(16)}, // One could do a horizontal widening addition with // pmaddwd against a vector of ones. Currently disabled // because I haven't found case where it's clearly better. }; - // clang-format on std::vector matches; for (const Pattern &p : patterns) { - if (p.factor != factor) { + if (op->op != p.reduce_op || p.factor != factor) { continue; } if (expr_match(p.pattern, op->value, matches)) { diff --git a/src/IR.h b/src/IR.h index ce5e16cac996..bdaf3c7f6e86 100644 --- a/src/IR.h +++ b/src/IR.h @@ -867,6 +867,7 @@ struct VectorReduce : public ExprNode { // operators. typedef enum { Add, + SaturatingAdd, Mul, Min, Max, diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 62fdb705997b..b02cc861ce35 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -279,6 +279,9 @@ ostream &operator<<(ostream &out, const VectorReduce::Operator &op) { case VectorReduce::Add: out << "Add"; break; + case VectorReduce::SaturatingAdd: + out << "SaturatingAdd"; + break; case VectorReduce::Mul: out << "Mul"; break; diff --git a/src/InlineReductions.cpp b/src/InlineReductions.cpp index 10714a4a1358..2fd49d701e0e 100644 --- a/src/InlineReductions.cpp +++ b/src/InlineReductions.cpp @@ -123,6 +123,22 @@ Expr sum(const RDom &r, Expr e, const std::string &name) { return f(v.call_args); } +Expr saturating_sum(Expr init_val, Expr e, const std::string &name) { + return saturating_sum(RDom(), init_val, e, name); +} + +Expr saturating_sum(const RDom &r, Expr init_val, Expr e, const std::string &name) { + Internal::FindFreeVars v(r, name); + e = v.mutate(common_subexpression_elimination(e)); + + user_assert(v.rdom.defined()) << "Expression passed to saturating_sum must reference a reduction domain"; + + Func f(name); + f(v.free_vars) = init_val; + f(v.free_vars) = Internal::saturating_add(f(v.free_vars), e); + return f(v.call_args); +} + Expr product(Expr e, const std::string &name) { return product(RDom(), std::move(e), name); } diff --git a/src/InlineReductions.h b/src/InlineReductions.h index 1b953e17162e..037bdb0f6836 100644 --- a/src/InlineReductions.h +++ b/src/InlineReductions.h @@ -36,6 +36,7 @@ namespace Halide { */ //@{ Expr sum(Expr, const std::string &s = "sum"); +Expr saturating_sum(Expr init_val, Expr e, const std::string &s = "saturating_sum"); Expr product(Expr, const std::string &s = "product"); Expr maximum(Expr, const std::string &s = "maximum"); Expr minimum(Expr, const std::string &s = "minimum"); @@ -52,6 +53,7 @@ Expr minimum(Expr, const std::string &s = "minimum"); */ // @{ Expr sum(const RDom &, Expr, const std::string &s = "sum"); +Expr saturating_sum(const RDom &r, Expr init_val, Expr e, const std::string &s = "saturating_sum"); Expr product(const RDom &, Expr, const std::string &s = "product"); Expr maximum(const RDom &, Expr, const std::string &s = "maximum"); Expr minimum(const RDom &, Expr, const std::string &s = "minimum"); diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index e31126b4949e..2e56a53d2da8 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -1093,6 +1093,12 @@ class VectorSubs : public IRMutator { reduce_op = VectorReduce::Or; } } + } else if (const Call *call_op = store->value.as()) { + if (call_op->is_intrinsic(Call::saturating_add)) { + a = call_op->args[0]; + b = call_op->args[1]; + reduce_op = VectorReduce::SaturatingAdd; + } } if (!a.defined() || !b.defined()) { @@ -1175,6 +1181,8 @@ class VectorSubs : public IRMutator { return a && b; case VectorReduce::Or: return a || b; + case VectorReduce::SaturatingAdd: + return saturating_add(a, b); } return Expr(); }; diff --git a/src/runtime/x86_avx512.ll b/src/runtime/x86_avx512.ll index 904fabe9368e..730014c99d9a 100644 --- a/src/runtime/x86_avx512.ll +++ b/src/runtime/x86_avx512.ll @@ -90,3 +90,51 @@ define weak_odr <4 x i32> @dpwssdx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b ret <4 x i32> %3 } declare <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <16 x i32> @dpbusdsx16(<16 x i32> %init, <64 x i8> %a, <64 x i8> %b) nounwind alwaysinline { + %1 = bitcast <64 x i8> %a to <16 x i32> + %2 = bitcast <64 x i8> %b to <16 x i32> + %3 = tail call <16 x i32> @llvm.x86.avx512.vpdpbusds.512(<16 x i32> %init, <16 x i32> %1, <16 x i32> %2) + ret <16 x i32> %3 +} +declare <16 x i32> @llvm.x86.avx512.vpdpbusds.512(<16 x i32>, <16 x i32>, <16 x i32>) + +define weak_odr <8 x i32> @dpbusdsx8(<8 x i32> %init, <32 x i8> %a, <32 x i8> %b) nounwind alwaysinline { + %1 = bitcast <32 x i8> %a to <8 x i32> + %2 = bitcast <32 x i8> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpbusds.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpbusds.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpbusdsx4(<4 x i32> %init, <16 x i8> %a, <16 x i8> %b) nounwind alwaysinline { + %1 = bitcast <16 x i8> %a to <4 x i32> + %2 = bitcast <16 x i8> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpbusds.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpbusds.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <16 x i32> @dpwssdsx16(<16 x i32> %init, <32 x i16> %a, <32 x i16> %b) nounwind alwaysinline { + %1 = bitcast <32 x i16> %a to <16 x i32> + %2 = bitcast <32 x i16> %b to <16 x i32> + %3 = tail call <16 x i32> @llvm.x86.avx512.vpdpwssds.512(<16 x i32> %init, <16 x i32> %1, <16 x i32> %2) + ret <16 x i32> %3 +} +declare <16 x i32> @llvm.x86.avx512.vpdpwssds.512(<16 x i32>, <16 x i32>, <16 x i32>) + +define weak_odr <8 x i32> @dpwssdsx8(<8 x i32> %init, <16 x i16> %a, <16 x i16> %b) nounwind alwaysinline { + %1 = bitcast <16 x i16> %a to <8 x i32> + %2 = bitcast <16 x i16> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpwssds.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpwssds.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpwssdsx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b) nounwind alwaysinline { + %1 = bitcast <8 x i16> %a to <4 x i32> + %2 = bitcast <8 x i16> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32>, <4 x i32>, <4 x i32>) diff --git a/test/correctness/simd_op_check.cpp b/test/correctness/simd_op_check.cpp index a67af7602a18..65ebe991336c 100644 --- a/test/correctness/simd_op_check.cpp +++ b/test/correctness/simd_op_check.cpp @@ -533,6 +533,23 @@ class SimdOpCheck : public SimdOpCheckTest { check("vpdpbusd*xmm", 4, sum(i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); check("vpdpbusd*xmm", 4, sum(i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); } + { + // 16 bit, 2 element saturaing dot product + RDom r(0, 2); + check("vpdpwssds*zmm", 16, saturating_sum(i32(0), i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); + check("vpdpwssds*ymm", 8, saturating_sum(i32(0), i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); + check("vpdpwssds*xmm", 4, saturating_sum(i32(0), i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); + } + { + // 8 bit, 4 element saturating dot product + RDom r(0, 4); + check("vpdpbusds*zmm", 16, saturating_sum(i32(0), i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); + check("vpdpbusds*zmm", 16, saturating_sum(i32(0), i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); + check("vpdpbusds*ymm", 8, saturating_sum(i32(0), i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); + check("vpdpbusds*ymm", 8, saturating_sum(i32(0), i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); + check("vpdpbusds*xmm", 4, saturating_sum(i32(0), i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); + check("vpdpbusds*xmm", 4, saturating_sum(i32(0), i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); + } } } diff --git a/test/correctness/simd_op_check.h b/test/correctness/simd_op_check.h index 43786143aecd..c496a7c2e19f 100644 --- a/test/correctness/simd_op_check.h +++ b/test/correctness/simd_op_check.h @@ -171,6 +171,7 @@ class SimdOpCheckTest { Internal::Function f(op->func); if (f.has_update_definition()) { inline_reduction = f; + override_associativity_test = op->name.find("saturating_sum") != std::string::npos; result = true; } } @@ -179,6 +180,7 @@ class SimdOpCheckTest { public: Internal::Function inline_reduction; + bool override_associativity_test = false; bool result = false; } has_inline_reduction; e.accept(&has_inline_reduction); @@ -208,7 +210,7 @@ class SimdOpCheckTest { g.compute_at(f, x) .update() .split(x, xo, xi, vector_width) - .atomic() + .atomic(has_inline_reduction.override_associativity_test) .vectorize(g.rvars()[0]) .vectorize(xi); } From 1244498216af95c467b9e8eb891d3be3329ab420 Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Mon, 15 Mar 2021 11:31:14 +0000 Subject: [PATCH 02/11] Improve code according to report from clang-tidy --- src/CodeGen_LLVM.cpp | 2 +- src/CodeGen_LLVM.h | 2 +- src/InlineReductions.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 92337351bbbb..c89c6445fc9b 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1437,7 +1437,7 @@ void CodeGen_LLVM::visit(const Variable *op) { } template -bool CodeGen_LLVM::try_to_fold_vector_reduce(Expr a, Expr b) { +bool CodeGen_LLVM::try_to_fold_vector_reduce(const Expr &a, Expr b) { const VectorReduce *red = a.as(); if (!red) { red = b.as(); diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index f2722230aba4..61393c764f00 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -567,7 +567,7 @@ class CodeGen_LLVM : public IRVisitor { /** A helper routine for generating folded vector reductions. */ template - bool try_to_fold_vector_reduce(Expr a, Expr b); + bool try_to_fold_vector_reduce(const Expr &a, Expr b); }; } // namespace Internal diff --git a/src/InlineReductions.cpp b/src/InlineReductions.cpp index 2fd49d701e0e..71d0fba5a439 100644 --- a/src/InlineReductions.cpp +++ b/src/InlineReductions.cpp @@ -124,7 +124,7 @@ Expr sum(const RDom &r, Expr e, const std::string &name) { } Expr saturating_sum(Expr init_val, Expr e, const std::string &name) { - return saturating_sum(RDom(), init_val, e, name); + return saturating_sum(RDom(), std::move(init_val), std::move(e), name); } Expr saturating_sum(const RDom &r, Expr init_val, Expr e, const std::string &name) { From d4697c74a7b79f4725071dddb76eb31e223630e5 Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Mon, 15 Mar 2021 12:24:59 +0000 Subject: [PATCH 03/11] Make init_val a const ref since it only used that way inside saturating_sum --- src/InlineReductions.cpp | 6 +++--- src/InlineReductions.h | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/InlineReductions.cpp b/src/InlineReductions.cpp index 71d0fba5a439..0ac191a2b5aa 100644 --- a/src/InlineReductions.cpp +++ b/src/InlineReductions.cpp @@ -123,11 +123,11 @@ Expr sum(const RDom &r, Expr e, const std::string &name) { return f(v.call_args); } -Expr saturating_sum(Expr init_val, Expr e, const std::string &name) { - return saturating_sum(RDom(), std::move(init_val), std::move(e), name); +Expr saturating_sum(const Expr &init_val, Expr e, const std::string &name) { + return saturating_sum(RDom(), init_val, std::move(e), name); } -Expr saturating_sum(const RDom &r, Expr init_val, Expr e, const std::string &name) { +Expr saturating_sum(const RDom &r, const Expr& init_val, Expr e, const std::string &name) { Internal::FindFreeVars v(r, name); e = v.mutate(common_subexpression_elimination(e)); diff --git a/src/InlineReductions.h b/src/InlineReductions.h index 037bdb0f6836..ea75cfc18e27 100644 --- a/src/InlineReductions.h +++ b/src/InlineReductions.h @@ -36,7 +36,7 @@ namespace Halide { */ //@{ Expr sum(Expr, const std::string &s = "sum"); -Expr saturating_sum(Expr init_val, Expr e, const std::string &s = "saturating_sum"); +Expr saturating_sum(const Expr &init_val, Expr e, const std::string &s = "saturating_sum"); Expr product(Expr, const std::string &s = "product"); Expr maximum(Expr, const std::string &s = "maximum"); Expr minimum(Expr, const std::string &s = "minimum"); @@ -53,7 +53,7 @@ Expr minimum(Expr, const std::string &s = "minimum"); */ // @{ Expr sum(const RDom &, Expr, const std::string &s = "sum"); -Expr saturating_sum(const RDom &r, Expr init_val, Expr e, const std::string &s = "saturating_sum"); +Expr saturating_sum(const RDom &r, const Expr &init_val, Expr e, const std::string &s = "saturating_sum"); Expr product(const RDom &, Expr, const std::string &s = "product"); Expr maximum(const RDom &, Expr, const std::string &s = "maximum"); Expr minimum(const RDom &, Expr, const std::string &s = "minimum"); From 1fbaccf3cd9ca07a50069e10518d089aad9ce6c6 Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Mon, 15 Mar 2021 13:41:11 +0000 Subject: [PATCH 04/11] clang-format --- src/InlineReductions.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/InlineReductions.cpp b/src/InlineReductions.cpp index 0ac191a2b5aa..1b3881ae137f 100644 --- a/src/InlineReductions.cpp +++ b/src/InlineReductions.cpp @@ -127,7 +127,7 @@ Expr saturating_sum(const Expr &init_val, Expr e, const std::string &name) { return saturating_sum(RDom(), init_val, std::move(e), name); } -Expr saturating_sum(const RDom &r, const Expr& init_val, Expr e, const std::string &name) { +Expr saturating_sum(const RDom &r, const Expr &init_val, Expr e, const std::string &name) { Internal::FindFreeVars v(r, name); e = v.mutate(common_subexpression_elimination(e)); From 05555090a644d118ebe98c79a40e1f3efb80a14f Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Mon, 15 Mar 2021 17:57:02 +0000 Subject: [PATCH 05/11] Revert removal of clang-format tag in CodeGen_X86.cpp --- src/CodeGen_X86.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index cc17eb8c21d6..6d6ff81a3666 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -551,6 +551,7 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init // pmaddwd against a vector of ones. Currently disabled // because I haven't found case where it's clearly better. }; + // clang-format on std::vector matches; for (const Pattern &p : patterns) { From 86afec8821af170047a8961b1a23f39a16b37fd3 Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Tue, 16 Mar 2021 10:30:57 +0000 Subject: [PATCH 06/11] Add SaturatingAdd case Monotonic VectorReduce visit --- src/Monotonic.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index 7383ed377dff..dd81eec80ad9 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -536,6 +536,7 @@ class DerivativeBounds : public IRVisitor { op->value.accept(this); switch (op->op) { case VectorReduce::Add: + case VectorReduce::SaturatingAdd: result = multiply(result, op->value.type().lanes() / op->type.lanes()); break; case VectorReduce::Min: From d441f76d14c81379291c4bc7afa207c52c186736 Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Tue, 16 Mar 2021 10:33:43 +0000 Subject: [PATCH 07/11] Bail out in Bounds when dealing with a SaturatingAdd VectorReduce --- src/Bounds.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 7b94d77e2e1b..b0d7926ef8b1 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1582,6 +1582,7 @@ class Bounds : public IRVisitor { interval.min *= factor; } break; + case VectorReduce::SaturatingAdd: case VectorReduce::Mul: // Technically there are some things we could say // here. E.g. if all the lanes are positive then we're From 872c5a25f533b03b99527d46c5bc1d54fb4572fa Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Tue, 16 Mar 2021 14:30:03 +0000 Subject: [PATCH 08/11] Move saturating_mul to Simplify_Internal.h so it can be used in Simplify_Exprs.cpp --- src/Simplify_Exprs.cpp | 8 ++++++++ src/Simplify_Internal.h | 14 ++++++++++++++ src/Simplify_Mul.cpp | 14 -------------- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index d0b2da780e40..9479e324e9a5 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -74,6 +74,14 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { bounds->max *= factor; } break; + case VectorReduce::SaturatingAdd: + if (bounds->min_defined) { + bounds->min = saturating_mul(bounds->min, factor); + } + if (bounds->max_defined) { + bounds->max = saturating_mul(bounds->max, factor); + } + break; case VectorReduce::Mul: // Don't try to infer anything about bounds. Leave the // alignment unchanged even though we could theoretically diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index f18e29645ac7..210011aa9521 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -28,6 +28,20 @@ namespace Halide { namespace Internal { +namespace { +inline int64_t saturating_mul(int64_t a, int64_t b) { + if (mul_would_overflow(64, a, b)) { + if ((a > 0) == (b > 0)) { + return INT64_MAX; + } else { + return INT64_MIN; + } + } else { + return a * b; + } +} +} // namespace + class Simplify : public VariadicVisitor { using Super = VariadicVisitor; diff --git a/src/Simplify_Mul.cpp b/src/Simplify_Mul.cpp index 08c194002316..631ad7b99d5f 100644 --- a/src/Simplify_Mul.cpp +++ b/src/Simplify_Mul.cpp @@ -3,20 +3,6 @@ namespace Halide { namespace Internal { -namespace { -int64_t saturating_mul(int64_t a, int64_t b) { - if (mul_would_overflow(64, a, b)) { - if ((a > 0) == (b > 0)) { - return INT64_MAX; - } else { - return INT64_MIN; - } - } else { - return a * b; - } -} -} // namespace - Expr Simplify::visit(const Mul *op, ExprInfo *bounds) { ExprInfo a_bounds, b_bounds; Expr a = mutate(op->a, &a_bounds); From a16d6615b760ae14ac080745493cba30e0b93ef4 Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Tue, 23 Mar 2021 10:10:53 +0000 Subject: [PATCH 09/11] Remove init_val from the saturating_sum inline reduction --- src/InlineReductions.cpp | 8 ++++---- src/InlineReductions.h | 4 ++-- test/correctness/simd_op_check.cpp | 18 +++++++++--------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/InlineReductions.cpp b/src/InlineReductions.cpp index 1b3881ae137f..cdeb43ce79a2 100644 --- a/src/InlineReductions.cpp +++ b/src/InlineReductions.cpp @@ -123,18 +123,18 @@ Expr sum(const RDom &r, Expr e, const std::string &name) { return f(v.call_args); } -Expr saturating_sum(const Expr &init_val, Expr e, const std::string &name) { - return saturating_sum(RDom(), init_val, std::move(e), name); +Expr saturating_sum(Expr e, const std::string &name) { + return saturating_sum(RDom(), std::move(e), name); } -Expr saturating_sum(const RDom &r, const Expr &init_val, Expr e, const std::string &name) { +Expr saturating_sum(const RDom &r, Expr e, const std::string &name) { Internal::FindFreeVars v(r, name); e = v.mutate(common_subexpression_elimination(e)); user_assert(v.rdom.defined()) << "Expression passed to saturating_sum must reference a reduction domain"; Func f(name); - f(v.free_vars) = init_val; + f(v.free_vars) = 0; f(v.free_vars) = Internal::saturating_add(f(v.free_vars), e); return f(v.call_args); } diff --git a/src/InlineReductions.h b/src/InlineReductions.h index ea75cfc18e27..143ee8cebe51 100644 --- a/src/InlineReductions.h +++ b/src/InlineReductions.h @@ -36,7 +36,7 @@ namespace Halide { */ //@{ Expr sum(Expr, const std::string &s = "sum"); -Expr saturating_sum(const Expr &init_val, Expr e, const std::string &s = "saturating_sum"); +Expr saturating_sum(Expr, const std::string &s = "saturating_sum"); Expr product(Expr, const std::string &s = "product"); Expr maximum(Expr, const std::string &s = "maximum"); Expr minimum(Expr, const std::string &s = "minimum"); @@ -53,7 +53,7 @@ Expr minimum(Expr, const std::string &s = "minimum"); */ // @{ Expr sum(const RDom &, Expr, const std::string &s = "sum"); -Expr saturating_sum(const RDom &r, const Expr &init_val, Expr e, const std::string &s = "saturating_sum"); +Expr saturating_sum(const RDom &r, Expr e, const std::string &s = "saturating_sum"); Expr product(const RDom &, Expr, const std::string &s = "product"); Expr maximum(const RDom &, Expr, const std::string &s = "maximum"); Expr minimum(const RDom &, Expr, const std::string &s = "minimum"); diff --git a/test/correctness/simd_op_check.cpp b/test/correctness/simd_op_check.cpp index 65ebe991336c..c8458610321a 100644 --- a/test/correctness/simd_op_check.cpp +++ b/test/correctness/simd_op_check.cpp @@ -536,19 +536,19 @@ class SimdOpCheck : public SimdOpCheckTest { { // 16 bit, 2 element saturaing dot product RDom r(0, 2); - check("vpdpwssds*zmm", 16, saturating_sum(i32(0), i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); - check("vpdpwssds*ymm", 8, saturating_sum(i32(0), i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); - check("vpdpwssds*xmm", 4, saturating_sum(i32(0), i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); + check("vpdpwssds*zmm", 16, saturating_sum(i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); + check("vpdpwssds*ymm", 8, saturating_sum(i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); + check("vpdpwssds*xmm", 4, saturating_sum(i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); } { // 8 bit, 4 element saturating dot product RDom r(0, 4); - check("vpdpbusds*zmm", 16, saturating_sum(i32(0), i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); - check("vpdpbusds*zmm", 16, saturating_sum(i32(0), i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); - check("vpdpbusds*ymm", 8, saturating_sum(i32(0), i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); - check("vpdpbusds*ymm", 8, saturating_sum(i32(0), i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); - check("vpdpbusds*xmm", 4, saturating_sum(i32(0), i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); - check("vpdpbusds*xmm", 4, saturating_sum(i32(0), i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); + check("vpdpbusds*zmm", 16, saturating_sum(i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); + check("vpdpbusds*zmm", 16, saturating_sum(i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); + check("vpdpbusds*ymm", 8, saturating_sum(i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); + check("vpdpbusds*ymm", 8, saturating_sum(i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); + check("vpdpbusds*xmm", 4, saturating_sum(i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); + check("vpdpbusds*xmm", 4, saturating_sum(i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); } } } From a1d2685499d1779bde36a8323c61235a8f232c81 Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Tue, 23 Mar 2021 16:25:09 +0000 Subject: [PATCH 10/11] Unconditionally override the associativity test in the simd_op_check tests --- test/correctness/simd_op_check.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/correctness/simd_op_check.h b/test/correctness/simd_op_check.h index c496a7c2e19f..49ec21f66c7e 100644 --- a/test/correctness/simd_op_check.h +++ b/test/correctness/simd_op_check.h @@ -171,7 +171,6 @@ class SimdOpCheckTest { Internal::Function f(op->func); if (f.has_update_definition()) { inline_reduction = f; - override_associativity_test = op->name.find("saturating_sum") != std::string::npos; result = true; } } @@ -180,7 +179,6 @@ class SimdOpCheckTest { public: Internal::Function inline_reduction; - bool override_associativity_test = false; bool result = false; } has_inline_reduction; e.accept(&has_inline_reduction); @@ -210,7 +208,7 @@ class SimdOpCheckTest { g.compute_at(f, x) .update() .split(x, xo, xi, vector_width) - .atomic(has_inline_reduction.override_associativity_test) + .atomic(true) .vectorize(g.rvars()[0]) .vectorize(xi); } From 57b7da6a4092eb3c78c10e0ed3fbc368e900990c Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Tue, 30 Mar 2021 09:59:03 +0100 Subject: [PATCH 11/11] Remove annonymous namespace from saturating_mul utility --- src/Simplify_Internal.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index 210011aa9521..1a59ccd32821 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -28,7 +28,6 @@ namespace Halide { namespace Internal { -namespace { inline int64_t saturating_mul(int64_t a, int64_t b) { if (mul_would_overflow(64, a, b)) { if ((a > 0) == (b > 0)) { @@ -40,7 +39,6 @@ inline int64_t saturating_mul(int64_t a, int64_t b) { return a * b; } } -} // namespace class Simplify : public VariadicVisitor { using Super = VariadicVisitor;