Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/Simplify_Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,16 @@ Expr Simplify::visit(const Add *op, ExprInfo *info) {

if (rewrite(IRMatcher::Overflow() + x, a) ||
rewrite(x + IRMatcher::Overflow(), b) ||
rewrite(x + 0, x) ||
rewrite(0 + x, x)) {
rewrite(x + 0, a) ||
rewrite(0 + x, b)) {
if (info) {
if (rewrite.result.same_as(a)) {
info->intersect(a_info);
} else {
internal_assert(rewrite.result.same_as(b));
info->intersect(b_info);
}
}
return rewrite.result;
}

Expand Down
2 changes: 1 addition & 1 deletion src/Simplify_Cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) {
// It's possible we just reduced to a constant. E.g. if we cast an
// even number to uint1 we get zero.
if (value_info.bounds.is_single_point()) {
return make_const(op->type, value_info.bounds.min, nullptr);
return make_const(op->type, value_info.bounds.min, info);
}
}

Expand Down
8 changes: 7 additions & 1 deletion src/Simplify_Exprs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@ Expr Simplify::visit(const IntImm *op, ExprInfo *info) {
}

Expr Simplify::visit(const UIntImm *op, ExprInfo *info) {
if (info && Int(64).can_represent(op->value)) {
if (info) {
// Pretend it's an int constant that has been cast to uint.
int64_t v = (int64_t)(op->value);
info->bounds = ConstantInterval::single_point(v);
info->alignment = ModulusRemainder(0, v);
// If it's not representable as an int64, this will wrap the alignment appropriately:
info->cast_to(op->type);
// Be as informative as we can with bounds for out-of-range uint64s
if ((int64_t)op->value < 0) {
info->bounds = ConstantInterval::bounded_below(INT64_MAX);
}
} else {
clear_expr_info(info);
}
Expand Down
34 changes: 32 additions & 2 deletions src/Simplify_Internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,34 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
}
}

// Truncate the bounds to the new type.
bounds.cast_to(t);
// We have to take special care with uint64, because their bounds
// and alignment may not be representable with ModulusRemainder and
// ConstantInterval.
if (t.bits() == 64 && t.is_uint()) {
// For UInt64 constants, the remainder might not be representable as an int64
if (alignment.modulus == 0 && alignment.remainder < 0) {
// Forget the leading two bits to get a representable modulus
// and remainder.
alignment.modulus = (int64_t)1 << 62;
alignment.remainder = alignment.remainder & (alignment.modulus - 1);
}

int64_t old_min = bounds.min;
bounds.cast_to(t);
if (bounds.min_defined && old_min > 0) {
// We don't want to lose a known positive min value for
// uint64s. In general a ConstantInterval represents
// infinite-precision integer intervals, and a cast from an infinite
// precision integer to a uint64 could overflow. However, in the
// simplifier, ConstantIntervals are used to represent bounds on the
// values a Halide::Expr could take on, and for all Halide Expr
// types, casting to a uint64_t can't overflow at the top end
// (e.g. double casts to uint64_t saturate).
bounds.min = old_min;
}
} else {
bounds.cast_to(t);
}
}

// Mix in existing knowledge about this Expr
Expand Down Expand Up @@ -241,6 +267,10 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
// We never want to return make_const anything in the simplifier without
// also setting the ExprInfo, so shadow the global make_const.
Expr make_const(const Type &t, int64_t c, ExprInfo *info) {
if (t.is_uint() && c < 0) {
// Wrap it around
return make_const(t, (uint64_t)c, info);
}
c = normalize_constant(t, c);
set_expr_info_to_constant(info, c);
return Halide::Internal::make_const(t, c);
Expand Down
17 changes: 10 additions & 7 deletions src/Simplify_Max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) {
if (max_info.bounds.is_single_point()) {
// This is possible when, for example, the largest number in the type
// that satisfies the alignment of the left-hand-side is smaller than
// the min value of the right-hand-side.
return make_const(op->type, max_info.bounds.min, nullptr);
// the min value of the right-hand-side. Reinferring the info can
// potentially give us something tighter than what was computed above if
// it's a large uint64.
return make_const(op->type, max_info.bounds.min, info);
}

auto strip_likely = [](const Expr &e) {
Expand Down Expand Up @@ -65,10 +67,10 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) {
return rewrite.result;
}

// Cases where one side dominates. All of these must reduce to a or b in the
// RHS for ExprInfo to update correctly.
if (EVAL_IN_LAMBDA //
(rewrite(max(x, x), a) ||
rewrite(max(c0, c1), fold(max(c0, c1))) ||
// Cases where one side dominates:
rewrite(max(x, c0), b, is_max_value(c0)) ||
rewrite(max(x, c0), a, is_min_value(c0)) ||
rewrite(max((x / c0) * c0, x), b, c0 > 0) ||
Expand Down Expand Up @@ -148,16 +150,17 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) {
// than just applying max to two constant intervals.
if (rewrite.result.same_as(a)) {
info->intersect(a_info);
} else if (rewrite.result.same_as(b)) {
} else {
internal_assert(rewrite.result.same_as(b));
info->intersect(b_info);
}
}

return rewrite.result;
}

if (EVAL_IN_LAMBDA //
(rewrite(max(max(x, c0), c1), max(x, fold(max(c0, c1)))) ||
(rewrite(max(c0, c1), fold(max(c0, c1))) ||
rewrite(max(max(x, c0), c1), max(x, fold(max(c0, c1)))) ||
rewrite(max(max(x, c0), y), max(max(x, y), c0)) ||
rewrite(max(max(x, y), max(x, z)), max(max(y, z), x)) ||
rewrite(max(max(y, x), max(x, z)), max(max(y, z), x)) ||
Expand Down
12 changes: 7 additions & 5 deletions src/Simplify_Min.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) {
// This is possible when, for example, the smallest number in the type
// that satisfies the alignment of the left-hand-side is greater than
// the max value of the right-hand-side.
return make_const(op->type, min_info.bounds.min, nullptr);
return make_const(op->type, min_info.bounds.min, info);
}

// Early out when the bounds tells us one side or the other is smaller
Expand Down Expand Up @@ -66,10 +66,10 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) {
return rewrite.result;
}

// Cases where one side dominates. All of these must reduce to a or b in the
// RHS for ExprInfo to update correctly.
if (EVAL_IN_LAMBDA //
(rewrite(min(x, x), a) ||
rewrite(min(c0, c1), fold(min(c0, c1))) ||
// Cases where one side dominates:
rewrite(min(x, c0), b, is_min_value(c0)) ||
rewrite(min(x, c0), a, is_max_value(c0)) ||
rewrite(min((x / c0) * c0, x), a, c0 > 0) ||
Expand Down Expand Up @@ -148,15 +148,17 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) {
if (info) {
if (rewrite.result.same_as(a)) {
info->intersect(a_info);
} else if (rewrite.result.same_as(b)) {
} else {
internal_assert(rewrite.result.same_as(b));
info->intersect(b_info);
}
}
return rewrite.result;
}

if (EVAL_IN_LAMBDA //
(rewrite(min(min(x, c0), c1), min(x, fold(min(c0, c1)))) ||
(rewrite(min(c0, c1), fold(min(c0, c1))) ||
rewrite(min(min(x, c0), c1), min(x, fold(min(c0, c1)))) ||
rewrite(min(min(x, c0), y), min(min(x, y), c0)) ||
rewrite(min(min(x, y), min(x, z)), min(min(y, z), x)) ||
rewrite(min(min(y, x), min(x, z)), min(min(y, z), x)) ||
Expand Down
16 changes: 12 additions & 4 deletions src/Simplify_Mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,18 @@ Expr Simplify::visit(const Mul *op, ExprInfo *info) {
return rewrite.result;
}

if (rewrite(0 * x, 0) ||
rewrite(1 * x, x) ||
rewrite(x * 0, 0) ||
rewrite(x * 1, x)) {
if (rewrite(0 * x, a) ||
rewrite(1 * x, b) ||
rewrite(x * 0, b) ||
rewrite(x * 1, a)) {
if (info) {
if (rewrite.result.same_as(a)) {
info->intersect(a_info);
} else {
internal_assert(rewrite.result.same_as(b));
info->intersect(b_info);
}
}
return rewrite.result;
}

Expand Down
3 changes: 2 additions & 1 deletion src/Simplify_Select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) {
if (info) {
if (rewrite.result.same_as(true_value)) {
*info = t_info;
} else if (rewrite.result.same_as(false_value)) {
} else {
internal_assert(rewrite.result.same_as(false_value));
*info = f_info;
}
}
Expand Down
10 changes: 9 additions & 1 deletion src/Simplify_Sub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@ Expr Simplify::visit(const Sub *op, ExprInfo *info) {

if (rewrite(IRMatcher::Overflow() - x, a) ||
rewrite(x - IRMatcher::Overflow(), b) ||
rewrite(x - 0, x)) {
rewrite(x - 0, a)) {
if (info) {
if (rewrite.result.same_as(a)) {
info->intersect(a_info);
} else {
internal_assert(rewrite.result.same_as(b));
info->intersect(b_info);
}
}
return rewrite.result;
}

Expand Down
2 changes: 0 additions & 2 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ tests(GROUPS correctness
fused_where_inner_extent_is_zero.cpp
fuzz_float_stores.cpp
fuzz_schedule.cpp
fuzz_simplify.cpp
gameoflife.cpp
gather.cpp
gpu_alloc_group_profiling.cpp
Expand Down Expand Up @@ -356,7 +355,6 @@ tests(GROUPS correctness
vectorized_initialization.cpp
vectorized_load_from_vectorized_allocation.cpp
vectorized_reduction_bug.cpp
widening_lerp.cpp
widening_reduction.cpp
# keep-sorted end
)
Expand Down
Loading
Loading