diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 7b94d77e2e1b..b0d7926ef8b1 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1582,6 +1582,7 @@ class Bounds : public IRVisitor { interval.min *= factor; } break; + case VectorReduce::SaturatingAdd: case VectorReduce::Mul: // Technically there are some things we could say // here. E.g. if all the lanes are positive then we're diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 133f5d20f779..c89c6445fc9b 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1437,12 +1437,11 @@ void CodeGen_LLVM::visit(const Variable *op) { } template -bool CodeGen_LLVM::try_to_fold_vector_reduce(const Op *op) { - const VectorReduce *red = op->a.template as(); - Expr b = op->b; +bool CodeGen_LLVM::try_to_fold_vector_reduce(const Expr &a, Expr b) { + const VectorReduce *red = a.as(); if (!red) { - red = op->b.template as(); - b = op->a; + red = b.as(); + b = a; } if (red && ((std::is_same::value && red->op == VectorReduce::Add) || @@ -1450,7 +1449,8 @@ bool CodeGen_LLVM::try_to_fold_vector_reduce(const Op *op) { (std::is_same::value && red->op == VectorReduce::Max) || (std::is_same::value && red->op == VectorReduce::Mul) || (std::is_same::value && red->op == VectorReduce::And) || - (std::is_same::value && red->op == VectorReduce::Or))) { + (std::is_same::value && red->op == VectorReduce::Or) || + (std::is_same::value && red->op == VectorReduce::SaturatingAdd))) { codegen_vector_reduce(red, b); return true; } @@ -1465,7 +1465,7 @@ void CodeGen_LLVM::visit(const Add *op) { } // Some backends can fold the add into a vector reduce - if (try_to_fold_vector_reduce(op)) { + if (try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -1509,7 +1509,7 @@ void CodeGen_LLVM::visit(const Mul *op) { return; } - if (try_to_fold_vector_reduce(op)) { + if (try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -1569,7 +1569,7 @@ void CodeGen_LLVM::visit(const Min *op) { return; } - if (try_to_fold_vector_reduce(op)) { + if 
(try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -1589,7 +1589,7 @@ void CodeGen_LLVM::visit(const Max *op) { return; } - if (try_to_fold_vector_reduce(op)) { + if (try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -1708,7 +1708,7 @@ void CodeGen_LLVM::visit(const GE *op) { } void CodeGen_LLVM::visit(const And *op) { - if (try_to_fold_vector_reduce(op)) { + if (try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -1718,7 +1718,7 @@ void CodeGen_LLVM::visit(const And *op) { } void CodeGen_LLVM::visit(const Or *op) { - if (try_to_fold_vector_reduce(op)) { + if (try_to_fold_vector_reduce(op->a, op->b)) { return; } @@ -2806,24 +2806,30 @@ void CodeGen_LLVM::visit(const Call *op) { } } else if (op->is_intrinsic(Call::saturating_add) || op->is_intrinsic(Call::saturating_sub)) { internal_assert(op->args.size() == 2); - std::string intrin; - if (op->type.is_int()) { - intrin = "llvm.s"; - } else { - internal_assert(op->type.is_uint()); - intrin = "llvm.u"; - } - if (op->is_intrinsic(Call::saturating_add)) { - intrin += "add.sat."; - } else { - internal_assert(op->is_intrinsic(Call::saturating_sub)); - intrin += "sub.sat."; - } - if (op->type.lanes() > 1) { - intrin += "v" + std::to_string(op->type.lanes()); + + // Try to fold the vector reduce for a call to saturating_add (the fold is not valid for saturating_sub) + const bool folded = op->is_intrinsic(Call::saturating_add) && try_to_fold_vector_reduce(op->args[0], op->args[1]); + + if (!folded) { + std::string intrin; + if (op->type.is_int()) { + intrin = "llvm.s"; + } else { + internal_assert(op->type.is_uint()); + intrin = "llvm.u"; + } + if (op->is_intrinsic(Call::saturating_add)) { + intrin += "add.sat."; + } else { + internal_assert(op->is_intrinsic(Call::saturating_sub)); + intrin += "sub.sat."; + } + if (op->type.lanes() > 1) { + intrin += "v" + std::to_string(op->type.lanes()); + } + intrin += "i" + std::to_string(op->type.bits()); + value = call_intrin(op->type, op->type.lanes(), intrin, op->args); } - intrin += "i" + std::to_string(op->type.bits()); - value = 
call_intrin(op->type, op->type.lanes(), intrin, op->args); } else if (op->is_intrinsic(Call::stringify)) { internal_assert(!op->args.empty()); @@ -4371,6 +4377,9 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini case VectorReduce::Or: binop = Or::make; break; + case VectorReduce::SaturatingAdd: + binop = saturating_add; + break; } if (op->type.is_bool() && op->op == VectorReduce::Or) { diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index 955987c8bd06..61393c764f00 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -567,7 +567,7 @@ class CodeGen_LLVM : public IRVisitor { /** A helper routine for generating folded vector reductions. */ template - bool try_to_fold_vector_reduce(const Op *op); + bool try_to_fold_vector_reduce(const Expr &a, Expr b); }; } // namespace Internal diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 2038dcce75c8..6d6ff81a3666 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -178,12 +178,22 @@ const x86Intrinsic intrinsic_defs[] = { {"dpbf16psx16", Float(32, 16), "dot_product", {Float(32, 16), BFloat(16, 32), BFloat(16, 32)}, Target::AVX512_SapphireRapids}, {"dpbf16psx8", Float(32, 8), "dot_product", {Float(32, 8), BFloat(16, 16), BFloat(16, 16)}, Target::AVX512_SapphireRapids}, {"dpbf16psx4", Float(32, 4), "dot_product", {Float(32, 4), BFloat(16, 8), BFloat(16, 8)}, Target::AVX512_SapphireRapids}, + {"dpbusdx16", Int(32, 16), "dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_SapphireRapids}, {"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids}, {"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids}, + {"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids}, {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids}, {"dpwssdx4", 
Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, + + {"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_SapphireRapids}, + {"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids}, + {"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids}, + + {"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids}, + {"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids}, + {"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, }; // clang-format on @@ -505,13 +515,14 @@ void CodeGen_X86::visit(const Call *op) { } void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { - if (op->op != VectorReduce::Add) { + if (op->op != VectorReduce::Add && op->op != VectorReduce::SaturatingAdd) { CodeGen_Posix::codegen_vector_reduce(op, init); return; } const int factor = op->value.type().lanes() / op->type.lanes(); struct Pattern { + VectorReduce::Operator reduce_op; int factor; Expr pattern; const char *intrin; @@ -524,15 +535,18 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init }; // clang-format off static const Pattern patterns[] = { - {2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit}, - {2, i32(widening_mul(wild_i16x_, wild_i16x_)), "dot_product", {}, Pattern::CombineInit}, - {4, i32(widening_mul(wild_u8x_, wild_i8x_)), "dot_product", {}, Pattern::CombineInit}, - {4, i32(widening_mul(wild_i8x_, wild_u8x_)), "dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands}, - {2, i32(widening_mul(wild_i16x_, wild_i16x_)), "pmaddwd", Int(16)}, - {2, 
i32(widening_mul(wild_i8x_, wild_i8x_)), "pmaddwd", Int(16)}, - {2, i32(widening_mul(wild_i8x_, wild_u8x_)), "pmaddwd", Int(16)}, - {2, i32(widening_mul(wild_u8x_, wild_i8x_)), "pmaddwd", Int(16)}, - {2, i32(widening_mul(wild_u8x_, wild_u8x_)), "pmaddwd", Int(16)}, + {VectorReduce::Add, 2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit}, + {VectorReduce::Add, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "dot_product", {}, Pattern::CombineInit}, + {VectorReduce::Add, 4, i32(widening_mul(wild_u8x_, wild_i8x_)), "dot_product", {}, Pattern::CombineInit}, + {VectorReduce::Add, 4, i32(widening_mul(wild_i8x_, wild_u8x_)), "dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands}, + {VectorReduce::SaturatingAdd, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "saturating_dot_product", {}, Pattern::CombineInit}, + {VectorReduce::SaturatingAdd, 4, i32(widening_mul(wild_u8x_, wild_i8x_)), "saturating_dot_product", {}, Pattern::CombineInit}, + {VectorReduce::SaturatingAdd, 4, i32(widening_mul(wild_i8x_, wild_u8x_)), "saturating_dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands}, + {VectorReduce::Add, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "pmaddwd", Int(16)}, + {VectorReduce::Add, 2, i32(widening_mul(wild_i8x_, wild_i8x_)), "pmaddwd", Int(16)}, + {VectorReduce::Add, 2, i32(widening_mul(wild_i8x_, wild_u8x_)), "pmaddwd", Int(16)}, + {VectorReduce::Add, 2, i32(widening_mul(wild_u8x_, wild_i8x_)), "pmaddwd", Int(16)}, + {VectorReduce::Add, 2, i32(widening_mul(wild_u8x_, wild_u8x_)), "pmaddwd", Int(16)}, // One could do a horizontal widening addition with // pmaddwd against a vector of ones. Currently disabled // because I haven't found case where it's clearly better. 
@@ -541,7 +555,7 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init std::vector matches; for (const Pattern &p : patterns) { - if (p.factor != factor) { + if (op->op != p.reduce_op || p.factor != factor) { continue; } if (expr_match(p.pattern, op->value, matches)) { diff --git a/src/IR.h b/src/IR.h index ce5e16cac996..bdaf3c7f6e86 100644 --- a/src/IR.h +++ b/src/IR.h @@ -867,6 +867,7 @@ struct VectorReduce : public ExprNode { // operators. typedef enum { Add, + SaturatingAdd, Mul, Min, Max, diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 62fdb705997b..b02cc861ce35 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -279,6 +279,9 @@ ostream &operator<<(ostream &out, const VectorReduce::Operator &op) { case VectorReduce::Add: out << "Add"; break; + case VectorReduce::SaturatingAdd: + out << "SaturatingAdd"; + break; case VectorReduce::Mul: out << "Mul"; break; diff --git a/src/InlineReductions.cpp b/src/InlineReductions.cpp index 10714a4a1358..cdeb43ce79a2 100644 --- a/src/InlineReductions.cpp +++ b/src/InlineReductions.cpp @@ -123,6 +123,22 @@ Expr sum(const RDom &r, Expr e, const std::string &name) { return f(v.call_args); } +Expr saturating_sum(Expr e, const std::string &name) { + return saturating_sum(RDom(), std::move(e), name); +} + +Expr saturating_sum(const RDom &r, Expr e, const std::string &name) { + Internal::FindFreeVars v(r, name); + e = v.mutate(common_subexpression_elimination(e)); + + user_assert(v.rdom.defined()) << "Expression passed to saturating_sum must reference a reduction domain"; + + Func f(name); + f(v.free_vars) = cast(e.type(), 0); + f(v.free_vars) = Internal::saturating_add(f(v.free_vars), e); + return f(v.call_args); +} + Expr product(Expr e, const std::string &name) { return product(RDom(), std::move(e), name); } diff --git a/src/InlineReductions.h b/src/InlineReductions.h index 1b953e17162e..143ee8cebe51 100644 --- a/src/InlineReductions.h +++ b/src/InlineReductions.h @@ -36,6 +36,7 @@ namespace Halide { 
*/ //@{ Expr sum(Expr, const std::string &s = "sum"); +Expr saturating_sum(Expr, const std::string &s = "saturating_sum"); Expr product(Expr, const std::string &s = "product"); Expr maximum(Expr, const std::string &s = "maximum"); Expr minimum(Expr, const std::string &s = "minimum"); @@ -52,6 +53,7 @@ Expr minimum(Expr, const std::string &s = "minimum"); */ // @{ Expr sum(const RDom &, Expr, const std::string &s = "sum"); +Expr saturating_sum(const RDom &r, Expr e, const std::string &s = "saturating_sum"); Expr product(const RDom &, Expr, const std::string &s = "product"); Expr maximum(const RDom &, Expr, const std::string &s = "maximum"); Expr minimum(const RDom &, Expr, const std::string &s = "minimum"); diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index 7383ed377dff..dd81eec80ad9 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -536,6 +536,7 @@ class DerivativeBounds : public IRVisitor { op->value.accept(this); switch (op->op) { case VectorReduce::Add: + case VectorReduce::SaturatingAdd: result = multiply(result, op->value.type().lanes() / op->type.lanes()); break; case VectorReduce::Min: diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index d0b2da780e40..9479e324e9a5 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -74,6 +74,14 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { bounds->max *= factor; } break; + case VectorReduce::SaturatingAdd: + if (bounds->min_defined) { + bounds->min = saturating_mul(bounds->min, factor); + } + if (bounds->max_defined) { + bounds->max = saturating_mul(bounds->max, factor); + } + break; case VectorReduce::Mul: // Don't try to infer anything about bounds. 
Leave the // alignment unchanged even though we could theoretically diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index f18e29645ac7..1a59ccd32821 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -28,6 +28,18 @@ namespace Halide { namespace Internal { +inline int64_t saturating_mul(int64_t a, int64_t b) { + if (mul_would_overflow(64, a, b)) { + if ((a > 0) == (b > 0)) { + return INT64_MAX; + } else { + return INT64_MIN; + } + } else { + return a * b; + } +} + class Simplify : public VariadicVisitor { using Super = VariadicVisitor; diff --git a/src/Simplify_Mul.cpp b/src/Simplify_Mul.cpp index 08c194002316..631ad7b99d5f 100644 --- a/src/Simplify_Mul.cpp +++ b/src/Simplify_Mul.cpp @@ -3,20 +3,6 @@ namespace Halide { namespace Internal { -namespace { -int64_t saturating_mul(int64_t a, int64_t b) { - if (mul_would_overflow(64, a, b)) { - if ((a > 0) == (b > 0)) { - return INT64_MAX; - } else { - return INT64_MIN; - } - } else { - return a * b; - } -} -} // namespace - Expr Simplify::visit(const Mul *op, ExprInfo *bounds) { ExprInfo a_bounds, b_bounds; Expr a = mutate(op->a, &a_bounds); diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index e31126b4949e..2e56a53d2da8 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -1093,6 +1093,12 @@ class VectorSubs : public IRMutator { reduce_op = VectorReduce::Or; } } + } else if (const Call *call_op = store->value.as()) { + if (call_op->is_intrinsic(Call::saturating_add)) { + a = call_op->args[0]; + b = call_op->args[1]; + reduce_op = VectorReduce::SaturatingAdd; + } } if (!a.defined() || !b.defined()) { @@ -1175,6 +1181,8 @@ class VectorSubs : public IRMutator { return a && b; case VectorReduce::Or: return a || b; + case VectorReduce::SaturatingAdd: + return saturating_add(a, b); } return Expr(); }; diff --git a/src/runtime/x86_avx512.ll b/src/runtime/x86_avx512.ll index 904fabe9368e..730014c99d9a 100644 --- a/src/runtime/x86_avx512.ll +++ 
b/src/runtime/x86_avx512.ll @@ -90,3 +90,51 @@ define weak_odr <4 x i32> @dpwssdx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b ret <4 x i32> %3 } declare <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <16 x i32> @dpbusdsx16(<16 x i32> %init, <64 x i8> %a, <64 x i8> %b) nounwind alwaysinline { + %1 = bitcast <64 x i8> %a to <16 x i32> + %2 = bitcast <64 x i8> %b to <16 x i32> + %3 = tail call <16 x i32> @llvm.x86.avx512.vpdpbusds.512(<16 x i32> %init, <16 x i32> %1, <16 x i32> %2) + ret <16 x i32> %3 +} +declare <16 x i32> @llvm.x86.avx512.vpdpbusds.512(<16 x i32>, <16 x i32>, <16 x i32>) + +define weak_odr <8 x i32> @dpbusdsx8(<8 x i32> %init, <32 x i8> %a, <32 x i8> %b) nounwind alwaysinline { + %1 = bitcast <32 x i8> %a to <8 x i32> + %2 = bitcast <32 x i8> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpbusds.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpbusds.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpbusdsx4(<4 x i32> %init, <16 x i8> %a, <16 x i8> %b) nounwind alwaysinline { + %1 = bitcast <16 x i8> %a to <4 x i32> + %2 = bitcast <16 x i8> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpbusds.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpbusds.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <16 x i32> @dpwssdsx16(<16 x i32> %init, <32 x i16> %a, <32 x i16> %b) nounwind alwaysinline { + %1 = bitcast <32 x i16> %a to <16 x i32> + %2 = bitcast <32 x i16> %b to <16 x i32> + %3 = tail call <16 x i32> @llvm.x86.avx512.vpdpwssds.512(<16 x i32> %init, <16 x i32> %1, <16 x i32> %2) + ret <16 x i32> %3 +} +declare <16 x i32> @llvm.x86.avx512.vpdpwssds.512(<16 x i32>, <16 x i32>, <16 x i32>) + +define weak_odr <8 x i32> @dpwssdsx8(<8 x i32> %init, <16 x i16> %a, <16 x i16> %b) nounwind alwaysinline { + %1 = bitcast <16 x i16> %a to 
<8 x i32> + %2 = bitcast <16 x i16> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpwssds.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpwssds.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpwssdsx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b) nounwind alwaysinline { + %1 = bitcast <8 x i16> %a to <4 x i32> + %2 = bitcast <8 x i16> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32>, <4 x i32>, <4 x i32>) diff --git a/test/correctness/simd_op_check.cpp b/test/correctness/simd_op_check.cpp index cc5d1186cb05..d50978c89ebc 100644 --- a/test/correctness/simd_op_check.cpp +++ b/test/correctness/simd_op_check.cpp @@ -533,6 +533,23 @@ class SimdOpCheck : public SimdOpCheckTest { check("vpdpbusd*xmm", 4, sum(i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); check("vpdpbusd*xmm", 4, sum(i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); } + { + // 16 bit, 2 element saturating dot product + RDom r(0, 2); + check("vpdpwssds*zmm", 16, saturating_sum(i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); + check("vpdpwssds*ymm", 8, saturating_sum(i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); + check("vpdpwssds*xmm", 4, saturating_sum(i32(in_i16(2 * x + r)) * in_i16(2 * x + r + 32))); + } + { + // 8 bit, 4 element saturating dot product + RDom r(0, 4); + check("vpdpbusds*zmm", 16, saturating_sum(i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); + check("vpdpbusds*zmm", 16, saturating_sum(i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); + check("vpdpbusds*ymm", 8, saturating_sum(i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); + check("vpdpbusds*ymm", 8, saturating_sum(i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); + check("vpdpbusds*xmm", 4, saturating_sum(i32(in_u8(4 * x + r)) * in_i8(4 * x + r + 32))); + check("vpdpbusds*xmm", 4, 
saturating_sum(i32(in_i8(4 * x + r)) * in_u8(4 * x + r + 32))); + } } } diff --git a/test/correctness/simd_op_check.h b/test/correctness/simd_op_check.h index 43786143aecd..49ec21f66c7e 100644 --- a/test/correctness/simd_op_check.h +++ b/test/correctness/simd_op_check.h @@ -208,7 +208,7 @@ class SimdOpCheckTest { g.compute_at(f, x) .update() .split(x, xo, xi, vector_width) - .atomic() + .atomic(true) .vectorize(g.rvars()[0]) .vectorize(xi); }