From d6c0b8eb26208c9c56afd4e4662b3b68cb6364b8 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 22 Mar 2021 08:13:09 -0600 Subject: [PATCH 1/4] Add simplifier rules helpful for specialization. --- src/Simplify_Add.cpp | 3 +++ src/Simplify_Max.cpp | 18 +++++++++++-- src/Simplify_Min.cpp | 18 +++++++++++-- src/Simplify_Select.cpp | 51 +++++++++++++++++++++++++++++++++-- src/Simplify_Stmts.cpp | 7 ++++- test/correctness/simplify.cpp | 12 +++++++++ 6 files changed, 102 insertions(+), 7 deletions(-) diff --git a/src/Simplify_Add.cpp b/src/Simplify_Add.cpp index 64ada056c357..c835dbe1c8e7 100644 --- a/src/Simplify_Add.cpp +++ b/src/Simplify_Add.cpp @@ -64,6 +64,9 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) { rewrite((broadcast(z, c1) - x) + broadcast(y, c0), broadcast(y + broadcast(z, fold(c1/c0)), c0) - x, c1 % c0 == 0) || rewrite(select(x, y, z) + select(x, w, u), select(x, y + w, z + u)) || rewrite(select(x, c0, c1) + c2, select(x, fold(c0 + c2), fold(c1 + c2))) || + rewrite(select(x, y + c0, c1) + c2, select(x, y + fold(c0 + c2), fold(c1 + c2))) || + rewrite(select(x, c0, z + c1) + c2, select(x, fold(c0 + c2), z + fold(c1 + c2))) || + rewrite(select(x, y + c0, z + c1) + c2, select(x, y + fold(c0 + c2), z + fold(c1 + c2))) || rewrite(ramp(broadcast(x, c0), y, c1) + broadcast(z, c2), ramp(broadcast(x + z, c0), y, c1), c2 == c0 * c1) || rewrite(ramp(ramp(x, y, c0), z, c1) + broadcast(w, c2), ramp(ramp(x + w, y, c0), z, c1), c2 == c0 * c1) || diff --git a/src/Simplify_Max.cpp b/src/Simplify_Max.cpp index 117956108535..0833c415a0f2 100644 --- a/src/Simplify_Max.cpp +++ b/src/Simplify_Max.cpp @@ -173,6 +173,22 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) { rewrite(max(x, max(min(x, y), z)), max(x, z)) || rewrite(max(x, max(min(y, x), z)), max(x, z)) || + rewrite(max(select(x, y, z), z), select(x, max(y, z), z)) || + rewrite(max(select(x, y, z), y), select(x, y, max(z, y))) || + rewrite(max(z, select(x, y, z)), select(x, max(z, y), z)) || + rewrite(max(y, select(x, y, z)), select(x, y, max(y, z))) || + + rewrite(max(select(x, min(y, z), w), z), select(x, z, max(w, z))) || + rewrite(max(select(x, min(z, y), w), z), select(x, z, max(w, z))) || + rewrite(max(z, select(x, min(y, z), w)), select(x, z, max(z, w))) || + rewrite(max(z, select(x, min(z, y), w)), select(x, z, max(z, w))) || + rewrite(max(select(x, y, min(w, z)), z), select(x, max(y, z), z)) || + rewrite(max(select(x, y, min(z, w)), z), select(x, max(y, z), z)) || + rewrite(max(z, select(x, y, min(w, z))), select(x, max(z, y), z)) || + rewrite(max(z, select(x, y, min(z, w))), select(x, max(z, y), z)) || + + rewrite(max(select(x, y, z), select(x, w, u)), select(x, max(y, w), max(z, u))) || + (no_overflow(op->type) && (rewrite(max(max(x, y) + c0, x), max(x, y + c0), c0 < 0) || rewrite(max(max(x, y) + c0, x), max(x, y) + c0, c0 > 0) || @@ -268,8 +284,6 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) { rewrite(max(((x + c0) / c1) * c1, x + c2), ((x + c0) / c1) * c1, c1 > 0 && c0 + 1 >= c1 + c2) || - rewrite(max(select(x, y, z), select(x, w, u)), select(x, max(y, w), max(z, u))) || - rewrite(max(c0 - x, c1), c0 - min(x, fold(c0 - c1))))))) { return mutate(rewrite.result, bounds); diff --git a/src/Simplify_Min.cpp b/src/Simplify_Min.cpp index 6fca6a52d11b..252ed8119b3c 100644 --- a/src/Simplify_Min.cpp +++ b/src/Simplify_Min.cpp @@ -176,6 +176,22 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) { rewrite(min(min(max(x, y), z), x), min(z, x)) || rewrite(min(min(max(x, y), z), y), min(z, y)) || + rewrite(min(select(x, y, z), z), select(x, min(y, z), z)) || + rewrite(min(select(x, y, z), y), select(x, y, min(z, y))) || + rewrite(min(z, select(x, y, z)), select(x, min(z, y), z)) || + rewrite(min(y, select(x, y, z)), select(x, y, min(y, z))) || + + rewrite(min(select(x, max(y, z), w), z), select(x, z, min(w, z))) || + rewrite(min(select(x, max(z, y), w), z), select(x, z, min(w, z))) || + rewrite(min(z, select(x, max(y, z), w)), select(x, z, min(z, w))) || + rewrite(min(z, select(x, max(z, y), w)), select(x, z, min(z, w))) || + rewrite(min(select(x, y, max(w, z)), z), select(x, min(y, z), z)) || + rewrite(min(select(x, y, max(z, w)), z), select(x, min(y, z), z)) || + rewrite(min(z, select(x, y, max(w, z))), select(x, min(z, y), z)) || + rewrite(min(z, select(x, y, max(z, w))), select(x, min(z, y), z)) || + + rewrite(min(select(x, y, z), select(x, w, u)), select(x, min(y, w), min(z, u))) || + (no_overflow(op->type) && (rewrite(min(min(x, y) + c0, x), min(x, y + c0), c0 > 0) || rewrite(min(min(x, y) + c0, x), min(x, y) + c0, c0 < 0) || @@ -273,8 +289,6 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) { rewrite(min(((x + c0) / c1) * c1, x + c2), x + c2, c1 > 0 && c0 + 1 >= c1 + c2) || - rewrite(min(select(x, y, z), select(x, w, u)), select(x, min(y, w), min(z, u))) || - rewrite(min(c0 - x, c1), c0 - max(x, fold(c0 - c1))) || // Required for nested GuardWithIf tilings diff --git a/src/Simplify_Select.cpp b/src/Simplify_Select.cpp index 34f3de0f778b..f761f110f07f 100644 --- a/src/Simplify_Select.cpp +++ b/src/Simplify_Select.cpp @@ -7,8 +7,55 @@ Expr Simplify::visit(const Select *op, ExprInfo *bounds) { ExprInfo t_bounds, f_bounds; Expr condition = mutate(op->condition, nullptr); - Expr true_value = mutate(op->true_value, &t_bounds); - Expr false_value = mutate(op->false_value, &f_bounds); + + Expr true_value, false_value; + Expr false_value_when_true, true_value_when_false; + { + auto f = scoped_truth(condition); + + true_value = mutate(op->true_value, &t_bounds); + Expr learned_true_value = f.substitute_facts(true_value); + if (!learned_true_value.same_as(true_value)) { + true_value = mutate(learned_true_value, &t_bounds); + } + + false_value_when_true = mutate(op->false_value, nullptr); + Expr learned_false_value_when_true = f.substitute_facts(false_value_when_true); + if (!learned_false_value_when_true.same_as(false_value_when_true)) { + false_value_when_true = mutate(learned_false_value_when_true, nullptr); + } + } + { + auto f = scoped_falsehood(condition); + + false_value = mutate(op->false_value, &f_bounds); + Expr learned_false_value = f.substitute_facts(false_value); + if (!learned_false_value.same_as(false_value)) { + false_value = mutate(learned_false_value, &f_bounds); + } + + true_value_when_false = mutate(op->true_value, nullptr); + Expr learned_true_value_when_false = f.substitute_facts(true_value_when_false); + if (!learned_true_value_when_false.same_as(true_value_when_false)) { + true_value_when_false = mutate(learned_true_value_when_false, nullptr); + } + } + + // If the false value when the condition is equal to the true value, + // the value is always equal to the false value. This simplifies + // things like select(x == 1, y, y*x) to y*x. + if (equal(false_value_when_true, true_value)) { + if (bounds) { + *bounds = f_bounds; + } + return false_value; + } + if (equal(true_value_when_false, false_value)) { + if (bounds) { + *bounds = t_bounds; + } + return true_value; + } if (bounds) { bounds->min_defined = t_bounds.min_defined && f_bounds.min_defined; diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 3c53b2b34a66..c991c10729d3 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -184,7 +184,12 @@ Stmt Simplify::visit(const For *op) { bounds_and_alignment_info.push(op->name, min_bounds); } - Stmt new_body = mutate(op->body); + Stmt new_body; + { + // If we're in the loop, the extent must be greater than 0. + ScopedFact fact = scoped_truth(0 < new_extent); + new_body = mutate(op->body); + } if (bounds_tracked) { bounds_and_alignment_info.pop(op->name); diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index e42c87fc07e8..5c7392f0702d 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -1289,6 +1289,12 @@ void check_boolean() { check(min(select((x == 1), -1, x), x), select((x == 1), -1, x)); check(min(select((x == -17), -1, x), x), x); + check(select(x == 1, y, x*y), x*y); + check(select(x != 1, x*y, y), x*y); + + check(min(select(x == 0, max(y, w), z), w), select(x == 0, w, min(w, z))); + check(max(select(x == 0, y, min(z, w)), w), select(x == 0, max(w, y), w)); + check((1 - xf) * 6 < 3, 0.5f < xf); check(!f, t); @@ -2245,6 +2251,12 @@ int main(int argc, char **argv) { check(slice(concat_vectors({vec_x, vec_y, vec_z}), 33, 2, 16), slice(concat_vectors({vec_y}), 1, 2, 16)); } + { + Stmt body = AssertStmt::make(x > 0, y); + check(For::make("t", 0, x, ForType::Serial, DeviceAPI::None, body), + Evaluate::make(0)); + } + // Check a bounds-related fuzz tester failure found in issue https://github.com/halide/Halide/issues/3764 check(Let::make("b", 105, 336 / max(cast(cast(Variable::make(Int(32), "b"))), 38) + 29), 32); From f49d3d0d9a7750bd89fa77e38ea28e08d706419a Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 22 Mar 2021 15:42:41 -0600 Subject: [PATCH 2/4] clang-format --- test/correctness/simplify.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index 5c7392f0702d..6f036392248a 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -1289,8 +1289,8 @@ void check_boolean() { check(min(select((x == 1), -1, x), x), select((x == 1), -1, x)); check(min(select((x == -17), -1, x), x), x); - check(select(x == 1, y, x*y), x*y); - check(select(x != 1, x*y, y), x*y); + check(select(x == 1, y, x * y), x * y); + check(select(x != 1, x * y, y), x * y); check(min(select(x == 0, max(y, w), z), w), select(x == 0, w, min(w, z))); check(max(select(x == 0, y, min(z, w)), w), select(x == 0, max(w, y), w)); From e1bbf469581221525c19aa3ef6e37f27480bc479 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 22 Mar 2021 17:54:03 -0600 Subject: [PATCH 3/4] Revert sketchy select simplification. --- src/Simplify_Select.cpp | 51 ++--------------------------------- test/correctness/simplify.cpp | 3 --- 2 files changed, 2 insertions(+), 52 deletions(-) diff --git a/src/Simplify_Select.cpp b/src/Simplify_Select.cpp index f761f110f07f..34f3de0f778b 100644 --- a/src/Simplify_Select.cpp +++ b/src/Simplify_Select.cpp @@ -7,55 +7,8 @@ Expr Simplify::visit(const Select *op, ExprInfo *bounds) { ExprInfo t_bounds, f_bounds; Expr condition = mutate(op->condition, nullptr); - - Expr true_value, false_value; - Expr false_value_when_true, true_value_when_false; - { - auto f = scoped_truth(condition); - - true_value = mutate(op->true_value, &t_bounds); - Expr learned_true_value = f.substitute_facts(true_value); - if (!learned_true_value.same_as(true_value)) { - true_value = mutate(learned_true_value, &t_bounds); - } - - false_value_when_true = mutate(op->false_value, nullptr); - Expr learned_false_value_when_true = f.substitute_facts(false_value_when_true); - if (!learned_false_value_when_true.same_as(false_value_when_true)) { - false_value_when_true = mutate(learned_false_value_when_true, nullptr); - } - } - { - auto f = scoped_falsehood(condition); - - false_value = mutate(op->false_value, &f_bounds); - Expr learned_false_value = f.substitute_facts(false_value); - if (!learned_false_value.same_as(false_value)) { - false_value = mutate(learned_false_value, &f_bounds); - } - - true_value_when_false = mutate(op->true_value, nullptr); - Expr learned_true_value_when_false = f.substitute_facts(true_value_when_false); - if (!learned_true_value_when_false.same_as(true_value_when_false)) { - true_value_when_false = mutate(learned_true_value_when_false, nullptr); - } - } - - // If the false value when the condition is equal to the true value, - // the value is always equal to the false value. This simplifies - // things like select(x == 1, y, y*x) to y*x. - if (equal(false_value_when_true, true_value)) { - if (bounds) { - *bounds = f_bounds; - } - return false_value; - } - if (equal(true_value_when_false, false_value)) { - if (bounds) { - *bounds = t_bounds; - } - return true_value; - } + Expr true_value = mutate(op->true_value, &t_bounds); + Expr false_value = mutate(op->false_value, &f_bounds); if (bounds) { bounds->min_defined = t_bounds.min_defined && f_bounds.min_defined; diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index 6f036392248a..d7a5b3649fa4 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -1289,9 +1289,6 @@ void check_boolean() { check(min(select((x == 1), -1, x), x), select((x == 1), -1, x)); check(min(select((x == -17), -1, x), x), x); - check(select(x == 1, y, x * y), x * y); - check(select(x != 1, x * y, y), x * y); - check(min(select(x == 0, max(y, w), z), w), select(x == 0, w, min(w, z))); check(max(select(x == 0, y, min(z, w)), w), select(x == 0, max(w, y), w)); From 5a3e507c144849f79899b97b1dbd058660188796 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Fri, 26 Mar 2021 15:36:24 -0600 Subject: [PATCH 4/4] Add different rules. --- src/Simplify_Max.cpp | 5 ----- src/Simplify_Min.cpp | 5 ----- src/Simplify_Select.cpp | 43 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/src/Simplify_Max.cpp b/src/Simplify_Max.cpp index 0833c415a0f2..df7a99777670 100644 --- a/src/Simplify_Max.cpp +++ b/src/Simplify_Max.cpp @@ -173,11 +173,6 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) { rewrite(max(x, max(min(x, y), z)), max(x, z)) || rewrite(max(x, max(min(y, x), z)), max(x, z)) || - rewrite(max(select(x, y, z), z), select(x, max(y, z), z)) || - rewrite(max(select(x, y, z), y), select(x, y, max(z, y))) || - rewrite(max(z, select(x, y, z)), select(x, max(z, y), z)) || - rewrite(max(y, select(x, y, z)), select(x, y, max(y, z))) || - rewrite(max(select(x, min(y, z), w), z), select(x, z, max(w, z))) || rewrite(max(select(x, min(z, y), w), z), select(x, z, max(w, z))) || rewrite(max(z, select(x, min(y, z), w)), select(x, z, max(z, w))) || diff --git a/src/Simplify_Min.cpp b/src/Simplify_Min.cpp index 252ed8119b3c..db93a3e40787 100644 --- a/src/Simplify_Min.cpp +++ b/src/Simplify_Min.cpp @@ -176,11 +176,6 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) { rewrite(min(min(max(x, y), z), x), min(z, x)) || rewrite(min(min(max(x, y), z), y), min(z, y)) || - rewrite(min(select(x, y, z), z), select(x, min(y, z), z)) || - rewrite(min(select(x, y, z), y), select(x, y, min(z, y))) || - rewrite(min(z, select(x, y, z)), select(x, min(z, y), z)) || - rewrite(min(y, select(x, y, z)), select(x, y, min(y, z))) || - rewrite(min(select(x, max(y, z), w), z), select(x, z, min(w, z))) || rewrite(min(select(x, max(z, y), w), z), select(x, z, min(w, z))) || rewrite(min(z, select(x, max(y, z), w)), select(x, z, min(z, w))) || diff --git a/src/Simplify_Select.cpp b/src/Simplify_Select.cpp index 34f3de0f778b..99bbb7615c1d 100644 --- a/src/Simplify_Select.cpp +++ b/src/Simplify_Select.cpp @@ -102,6 +102,49 @@ Expr Simplify::visit(const Select *op, ExprInfo *bounds) { rewrite(select(x < 0, x * y, 0), min(x, 0) * y) || rewrite(select(x < 0, 0, x * y), max(x, 0) * y) || + rewrite(select(x, min(y, w), min(z, w)), min(select(x, y, z), w)) || + rewrite(select(x, min(y, w), min(w, z)), min(select(x, y, z), w)) || + rewrite(select(x, min(w, y), min(z, w)), min(w, select(x, y, z))) || + rewrite(select(x, min(w, y), min(w, z)), min(w, select(x, y, z))) || + rewrite(select(x, max(y, w), max(z, w)), max(select(x, y, z), w)) || + rewrite(select(x, max(y, w), max(w, z)), max(select(x, y, z), w)) || + rewrite(select(x, max(w, y), max(z, w)), max(w, select(x, y, z))) || + rewrite(select(x, max(w, y), max(w, z)), max(w, select(x, y, z))) || + + rewrite(select(x, select(y, z, min(w, z)), min(u, z)), min(select(x, select(y, z, w), u), z)) || + rewrite(select(x, select(y, min(w, z), z), min(u, z)), min(select(x, select(y, w, z), u), z)) || + rewrite(select(x, min(u, z), select(y, z, min(w, z))), min(select(x, u, select(y, z, w)), z)) || + rewrite(select(x, min(u, z), select(y, min(w, z), z)), min(select(x, u, select(y, w, z)), z)) || + rewrite(select(x, select(y, z, min(w, z)), min(z, u)), min(select(x, select(y, z, w), u), z)) || + rewrite(select(x, select(y, min(w, z), z), min(z, u)), min(select(x, select(y, w, z), u), z)) || + rewrite(select(x, min(z, u), select(y, z, min(w, z))), min(select(x, u, select(y, z, w)), z)) || + rewrite(select(x, min(z, u), select(y, min(w, z), z)), min(select(x, u, select(y, w, z)), z)) || + rewrite(select(x, select(y, z, min(z, w)), min(u, z)), min(select(x, select(y, z, w), u), z)) || + rewrite(select(x, select(y, min(z, w), z), min(u, z)), min(select(x, select(y, w, z), u), z)) || + rewrite(select(x, min(u, z), select(y, z, min(z, w))), min(select(x, u, select(y, z, w)), z)) || + rewrite(select(x, min(u, z), select(y, min(z, w), z)), min(select(x, u, select(y, w, z)), z)) || + rewrite(select(x, select(y, z, min(z, w)), min(z, u)), min(select(x, select(y, z, w), u), z)) || + rewrite(select(x, select(y, min(z, w), z), min(z, u)), min(select(x, select(y, w, z), u), z)) || + rewrite(select(x, min(z, u), select(y, z, min(z, w))), min(select(x, u, select(y, z, w)), z)) || + rewrite(select(x, min(z, u), select(y, min(z, w), z)), min(select(x, u, select(y, w, z)), z)) || + + rewrite(select(x, select(y, z, max(w, z)), max(u, z)), max(select(x, select(y, z, w), u), z)) || + rewrite(select(x, select(y, max(w, z), z), max(u, z)), max(select(x, select(y, w, z), u), z)) || + rewrite(select(x, max(u, z), select(y, z, max(w, z))), max(select(x, u, select(y, z, w)), z)) || + rewrite(select(x, max(u, z), select(y, max(w, z), z)), max(select(x, u, select(y, w, z)), z)) || + rewrite(select(x, select(y, z, max(w, z)), max(z, u)), max(select(x, select(y, z, w), u), z)) || + rewrite(select(x, select(y, max(w, z), z), max(z, u)), max(select(x, select(y, w, z), u), z)) || + rewrite(select(x, max(z, u), select(y, z, max(w, z))), max(select(x, u, select(y, z, w)), z)) || + rewrite(select(x, max(z, u), select(y, max(w, z), z)), max(select(x, u, select(y, w, z)), z)) || + rewrite(select(x, select(y, z, max(z, w)), max(u, z)), max(select(x, select(y, z, w), u), z)) || + rewrite(select(x, select(y, max(z, w), z), max(u, z)), max(select(x, select(y, w, z), u), z)) || + rewrite(select(x, max(u, z), select(y, z, max(z, w))), max(select(x, u, select(y, z, w)), z)) || + rewrite(select(x, max(u, z), select(y, max(z, w), z)), max(select(x, u, select(y, w, z)), z)) || + rewrite(select(x, select(y, z, max(z, w)), max(z, u)), max(select(x, select(y, z, w), u), z)) || + rewrite(select(x, select(y, max(z, w), z), max(z, u)), max(select(x, select(y, w, z), u), z)) || + rewrite(select(x, max(z, u), select(y, z, max(z, w))), max(select(x, u, select(y, z, w)), z)) || + rewrite(select(x, max(z, u), select(y, max(z, w), z)), max(select(x, u, select(y, w, z)), z)) || + // Note that in the rules below we know y is not a // constant because it appears on the LHS of an // addition. These rules therefore trade a non-constant