Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,7 @@ class Bounds : public IRVisitor {
interval.min *= factor;
}
break;
case VectorReduce::SaturatingAdd:
case VectorReduce::Mul:
// Technically there are some things we could say
// here. E.g. if all the lanes are positive then we're
Expand Down
67 changes: 38 additions & 29 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1437,20 +1437,20 @@ void CodeGen_LLVM::visit(const Variable *op) {
}

template<typename Op>
bool CodeGen_LLVM::try_to_fold_vector_reduce(const Op *op) {
const VectorReduce *red = op->a.template as<VectorReduce>();
Expr b = op->b;
bool CodeGen_LLVM::try_to_fold_vector_reduce(const Expr &a, Expr b) {
const VectorReduce *red = a.as<VectorReduce>();
if (!red) {
red = op->b.template as<VectorReduce>();
b = op->a;
red = b.as<VectorReduce>();
b = a;
}
if (red &&
((std::is_same<Op, Add>::value && red->op == VectorReduce::Add) ||
(std::is_same<Op, Min>::value && red->op == VectorReduce::Min) ||
(std::is_same<Op, Max>::value && red->op == VectorReduce::Max) ||
(std::is_same<Op, Mul>::value && red->op == VectorReduce::Mul) ||
(std::is_same<Op, And>::value && red->op == VectorReduce::And) ||
(std::is_same<Op, Or>::value && red->op == VectorReduce::Or))) {
(std::is_same<Op, Or>::value && red->op == VectorReduce::Or) ||
(std::is_same<Op, Call>::value && red->op == VectorReduce::SaturatingAdd))) {
codegen_vector_reduce(red, b);
return true;
}
Expand All @@ -1465,7 +1465,7 @@ void CodeGen_LLVM::visit(const Add *op) {
}

// Some backends can fold the add into a vector reduce
if (try_to_fold_vector_reduce(op)) {
if (try_to_fold_vector_reduce<Add>(op->a, op->b)) {
return;
}

Expand Down Expand Up @@ -1509,7 +1509,7 @@ void CodeGen_LLVM::visit(const Mul *op) {
return;
}

if (try_to_fold_vector_reduce(op)) {
if (try_to_fold_vector_reduce<Mul>(op->a, op->b)) {
return;
}

Expand Down Expand Up @@ -1569,7 +1569,7 @@ void CodeGen_LLVM::visit(const Min *op) {
return;
}

if (try_to_fold_vector_reduce(op)) {
if (try_to_fold_vector_reduce<Min>(op->a, op->b)) {
return;
}

Expand All @@ -1589,7 +1589,7 @@ void CodeGen_LLVM::visit(const Max *op) {
return;
}

if (try_to_fold_vector_reduce(op)) {
if (try_to_fold_vector_reduce<Max>(op->a, op->b)) {
return;
}

Expand Down Expand Up @@ -1708,7 +1708,7 @@ void CodeGen_LLVM::visit(const GE *op) {
}

void CodeGen_LLVM::visit(const And *op) {
if (try_to_fold_vector_reduce(op)) {
if (try_to_fold_vector_reduce<And>(op->a, op->b)) {
return;
}

Expand All @@ -1718,7 +1718,7 @@ void CodeGen_LLVM::visit(const And *op) {
}

void CodeGen_LLVM::visit(const Or *op) {
if (try_to_fold_vector_reduce(op)) {
if (try_to_fold_vector_reduce<Or>(op->a, op->b)) {
return;
}

Expand Down Expand Up @@ -2806,24 +2806,30 @@ void CodeGen_LLVM::visit(const Call *op) {
}
} else if (op->is_intrinsic(Call::saturating_add) || op->is_intrinsic(Call::saturating_sub)) {
internal_assert(op->args.size() == 2);
std::string intrin;
if (op->type.is_int()) {
intrin = "llvm.s";
} else {
internal_assert(op->type.is_uint());
intrin = "llvm.u";
}
if (op->is_intrinsic(Call::saturating_add)) {
intrin += "add.sat.";
} else {
internal_assert(op->is_intrinsic(Call::saturating_sub));
intrin += "sub.sat.";
}
if (op->type.lanes() > 1) {
intrin += "v" + std::to_string(op->type.lanes());

// Try to fold the vector reduce for a call to saturating_add
const bool folded = try_to_fold_vector_reduce<Call>(op->args[0], op->args[1]);

if (!folded) {
std::string intrin;
if (op->type.is_int()) {
intrin = "llvm.s";
} else {
internal_assert(op->type.is_uint());
intrin = "llvm.u";
}
if (op->is_intrinsic(Call::saturating_add)) {
intrin += "add.sat.";
} else {
internal_assert(op->is_intrinsic(Call::saturating_sub));
intrin += "sub.sat.";
}
if (op->type.lanes() > 1) {
intrin += "v" + std::to_string(op->type.lanes());
}
intrin += "i" + std::to_string(op->type.bits());
value = call_intrin(op->type, op->type.lanes(), intrin, op->args);
}
intrin += "i" + std::to_string(op->type.bits());
value = call_intrin(op->type, op->type.lanes(), intrin, op->args);
} else if (op->is_intrinsic(Call::stringify)) {
internal_assert(!op->args.empty());

Expand Down Expand Up @@ -4371,6 +4377,9 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini
case VectorReduce::Or:
binop = Or::make;
break;
case VectorReduce::SaturatingAdd:
binop = saturating_add;
break;
}

if (op->type.is_bool() && op->op == VectorReduce::Or) {
Expand Down
2 changes: 1 addition & 1 deletion src/CodeGen_LLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ class CodeGen_LLVM : public IRVisitor {

/** A helper routine for generating folded vector reductions. */
template<typename Op>
bool try_to_fold_vector_reduce(const Op *op);
bool try_to_fold_vector_reduce(const Expr &a, Expr b);
};

} // namespace Internal
Expand Down
36 changes: 25 additions & 11 deletions src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,22 @@ const x86Intrinsic intrinsic_defs[] = {
{"dpbf16psx16", Float(32, 16), "dot_product", {Float(32, 16), BFloat(16, 32), BFloat(16, 32)}, Target::AVX512_SapphireRapids},
{"dpbf16psx8", Float(32, 8), "dot_product", {Float(32, 8), BFloat(16, 16), BFloat(16, 16)}, Target::AVX512_SapphireRapids},
{"dpbf16psx4", Float(32, 4), "dot_product", {Float(32, 4), BFloat(16, 8), BFloat(16, 8)}, Target::AVX512_SapphireRapids},

{"dpbusdx16", Int(32, 16), "dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_SapphireRapids},
{"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids},
{"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids},

{"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids},
{"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
{"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},

{"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_SapphireRapids},
{"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids},
{"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids},

{"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids},
{"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
{"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},
};
// clang-format on

Expand Down Expand Up @@ -505,13 +515,14 @@ void CodeGen_X86::visit(const Call *op) {
}

void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init) {
if (op->op != VectorReduce::Add) {
if (op->op != VectorReduce::Add && op->op != VectorReduce::SaturatingAdd) {
CodeGen_Posix::codegen_vector_reduce(op, init);
return;
}
const int factor = op->value.type().lanes() / op->type.lanes();

struct Pattern {
VectorReduce::Operator reduce_op;
int factor;
Expr pattern;
const char *intrin;
Expand All @@ -524,15 +535,18 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
};
// clang-format off
static const Pattern patterns[] = {
{2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit},
{2, i32(widening_mul(wild_i16x_, wild_i16x_)), "dot_product", {}, Pattern::CombineInit},
{4, i32(widening_mul(wild_u8x_, wild_i8x_)), "dot_product", {}, Pattern::CombineInit},
{4, i32(widening_mul(wild_i8x_, wild_u8x_)), "dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands},
{2, i32(widening_mul(wild_i16x_, wild_i16x_)), "pmaddwd", Int(16)},
{2, i32(widening_mul(wild_i8x_, wild_i8x_)), "pmaddwd", Int(16)},
{2, i32(widening_mul(wild_i8x_, wild_u8x_)), "pmaddwd", Int(16)},
{2, i32(widening_mul(wild_u8x_, wild_i8x_)), "pmaddwd", Int(16)},
{2, i32(widening_mul(wild_u8x_, wild_u8x_)), "pmaddwd", Int(16)},
{VectorReduce::Add, 2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit},
{VectorReduce::Add, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "dot_product", {}, Pattern::CombineInit},
{VectorReduce::Add, 4, i32(widening_mul(wild_u8x_, wild_i8x_)), "dot_product", {}, Pattern::CombineInit},
{VectorReduce::Add, 4, i32(widening_mul(wild_i8x_, wild_u8x_)), "dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands},
{VectorReduce::SaturatingAdd, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "saturating_dot_product", {}, Pattern::CombineInit},
{VectorReduce::SaturatingAdd, 4, i32(widening_mul(wild_u8x_, wild_i8x_)), "saturating_dot_product", {}, Pattern::CombineInit},
{VectorReduce::SaturatingAdd, 4, i32(widening_mul(wild_i8x_, wild_u8x_)), "saturating_dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands},
{VectorReduce::Add, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "pmaddwd", Int(16)},
{VectorReduce::Add, 2, i32(widening_mul(wild_i8x_, wild_i8x_)), "pmaddwd", Int(16)},
{VectorReduce::Add, 2, i32(widening_mul(wild_i8x_, wild_u8x_)), "pmaddwd", Int(16)},
{VectorReduce::Add, 2, i32(widening_mul(wild_u8x_, wild_i8x_)), "pmaddwd", Int(16)},
{VectorReduce::Add, 2, i32(widening_mul(wild_u8x_, wild_u8x_)), "pmaddwd", Int(16)},
// One could do a horizontal widening addition with
// pmaddwd against a vector of ones. Currently disabled
// because I haven't found case where it's clearly better.
Expand All @@ -541,7 +555,7 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init

std::vector<Expr> matches;
for (const Pattern &p : patterns) {
if (p.factor != factor) {
if (op->op != p.reduce_op || p.factor != factor) {
continue;
}
if (expr_match(p.pattern, op->value, matches)) {
Expand Down
1 change: 1 addition & 0 deletions src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,7 @@ struct VectorReduce : public ExprNode<VectorReduce> {
// operators.
typedef enum {
Add,
SaturatingAdd,
Mul,
Min,
Max,
Expand Down
3 changes: 3 additions & 0 deletions src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,9 @@ ostream &operator<<(ostream &out, const VectorReduce::Operator &op) {
case VectorReduce::Add:
out << "Add";
break;
case VectorReduce::SaturatingAdd:
out << "SaturatingAdd";
break;
case VectorReduce::Mul:
out << "Mul";
break;
Expand Down
16 changes: 16 additions & 0 deletions src/InlineReductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,22 @@ Expr sum(const RDom &r, Expr e, const std::string &name) {
return f(v.call_args);
}

Expr saturating_sum(Expr e, const std::string &name) {
return saturating_sum(RDom(), std::move(e), name);
}

Expr saturating_sum(const RDom &r, Expr e, const std::string &name) {
Internal::FindFreeVars v(r, name);
e = v.mutate(common_subexpression_elimination(e));

user_assert(v.rdom.defined()) << "Expression passed to saturating_sum must reference a reduction domain";

Func f(name);
f(v.free_vars) = 0;
f(v.free_vars) = Internal::saturating_add(f(v.free_vars), e);
return f(v.call_args);
}

Expr product(Expr e, const std::string &name) {
return product(RDom(), std::move(e), name);
}
Expand Down
2 changes: 2 additions & 0 deletions src/InlineReductions.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace Halide {
*/
//@{
Expr sum(Expr, const std::string &s = "sum");
Expr saturating_sum(Expr, const std::string &s = "saturating_sum");
Expr product(Expr, const std::string &s = "product");
Expr maximum(Expr, const std::string &s = "maximum");
Expr minimum(Expr, const std::string &s = "minimum");
Expand All @@ -52,6 +53,7 @@ Expr minimum(Expr, const std::string &s = "minimum");
*/
// @{
Expr sum(const RDom &, Expr, const std::string &s = "sum");
Expr saturating_sum(const RDom &r, Expr e, const std::string &s = "saturating_sum");
Expr product(const RDom &, Expr, const std::string &s = "product");
Expr maximum(const RDom &, Expr, const std::string &s = "maximum");
Expr minimum(const RDom &, Expr, const std::string &s = "minimum");
Expand Down
1 change: 1 addition & 0 deletions src/Monotonic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,7 @@ class DerivativeBounds : public IRVisitor {
op->value.accept(this);
switch (op->op) {
case VectorReduce::Add:
case VectorReduce::SaturatingAdd:
result = multiply(result, op->value.type().lanes() / op->type.lanes());
break;
case VectorReduce::Min:
Expand Down
8 changes: 8 additions & 0 deletions src/Simplify_Exprs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) {
bounds->max *= factor;
}
break;
case VectorReduce::SaturatingAdd:
if (bounds->min_defined) {
bounds->min = saturating_mul(bounds->min, factor);
}
if (bounds->max_defined) {
bounds->max = saturating_mul(bounds->max, factor);
}
break;
case VectorReduce::Mul:
// Don't try to infer anything about bounds. Leave the
// alignment unchanged even though we could theoretically
Expand Down
12 changes: 12 additions & 0 deletions src/Simplify_Internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@
namespace Halide {
namespace Internal {

inline int64_t saturating_mul(int64_t a, int64_t b) {
if (mul_would_overflow(64, a, b)) {
if ((a > 0) == (b > 0)) {
return INT64_MAX;
} else {
return INT64_MIN;
}
} else {
return a * b;
}
}

class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
using Super = VariadicVisitor<Simplify, Expr, Stmt>;

Expand Down
14 changes: 0 additions & 14 deletions src/Simplify_Mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,6 @@
namespace Halide {
namespace Internal {

namespace {
int64_t saturating_mul(int64_t a, int64_t b) {
if (mul_would_overflow(64, a, b)) {
if ((a > 0) == (b > 0)) {
return INT64_MAX;
} else {
return INT64_MIN;
}
} else {
return a * b;
}
}
} // namespace

Expr Simplify::visit(const Mul *op, ExprInfo *bounds) {
ExprInfo a_bounds, b_bounds;
Expr a = mutate(op->a, &a_bounds);
Expand Down
8 changes: 8 additions & 0 deletions src/VectorizeLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,12 @@ class VectorSubs : public IRMutator {
reduce_op = VectorReduce::Or;
}
}
} else if (const Call *call_op = store->value.as<Call>()) {
if (call_op->is_intrinsic(Call::saturating_add)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably run find_intrinsics before vectorize_loops in lowering, otherwise this will not work if people write saturating add as a pattern they expect to match rather than using the (currently Halide::Internal) intrinsic.

This is a fairly big change to make, because the simplifier and other lowering passes don't understand saturating_add (yet? @rootjalex @abadams).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds reasonable, I will try to move the find_intrinsics earlier in the pipeline and I will the tests.

Copy link
Contributor Author

@mcleary mcleary Mar 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dsharletg I placed the find_intrisics pass right before vectorize_loops, thanks for the suggestion (b880a61)

Tests are passing locally, I will keep an eye on the buildbots

a = call_op->args[0];
b = call_op->args[1];
reduce_op = VectorReduce::SaturatingAdd;
}
}

if (!a.defined() || !b.defined()) {
Expand Down Expand Up @@ -1175,6 +1181,8 @@ class VectorSubs : public IRMutator {
return a && b;
case VectorReduce::Or:
return a || b;
case VectorReduce::SaturatingAdd:
return saturating_add(a, b);
}
return Expr();
};
Expand Down
Loading