From 9fe68a09915db9ddb95b7c41b61fc656bd09330d Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 8 Mar 2021 08:33:18 -0700 Subject: [PATCH 1/2] Handle some reassociation when simplifying nested broadcasts. --- src/Simplify_Add.cpp | 11 ++++++++--- test/correctness/simplify.cpp | 4 ++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/Simplify_Add.cpp b/src/Simplify_Add.cpp index b25c7c7e6a46..64ada056c357 100644 --- a/src/Simplify_Add.cpp +++ b/src/Simplify_Add.cpp @@ -51,12 +51,17 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) { (rewrite(x + x, x * 2) || rewrite(ramp(x, y, c0) + ramp(z, w, c0), ramp(x + z, y + w, c0)) || rewrite(ramp(x, y, c0) + broadcast(z, c0), ramp(x + z, y, c0)) || - rewrite(broadcast(x, c0) + broadcast(y, c0), broadcast(x + y, c0)) || rewrite(broadcast(x, c0) + broadcast(y, c1), broadcast(x + broadcast(y, fold(c1/c0)), c0), c1 % c0 == 0) || rewrite(broadcast(y, c1) + broadcast(x, c0), broadcast(x + broadcast(y, fold(c1/c0)), c0), c1 % c0 == 0) || - rewrite((x + broadcast(y, c0)) + broadcast(z, c0), x + broadcast(y + z, c0)) || - rewrite((x - broadcast(y, c0)) + broadcast(z, c0), x + broadcast(z - y, c0)) || + rewrite((x + broadcast(y, c0)) + broadcast(z, c1), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || + rewrite((x + broadcast(z, c1)) + broadcast(y, c0), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || + rewrite((broadcast(y, c0) + x) + broadcast(z, c1), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || + rewrite((broadcast(z, c1) + x) + broadcast(y, c0), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || + rewrite((x - broadcast(y, c0)) + broadcast(z, c1), x + broadcast(broadcast(z, fold(c1/c0)) - y, c0), c1 % c0 == 0) || + rewrite((x - broadcast(z, c1)) + broadcast(y, c0), x + broadcast(y - broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || + rewrite((broadcast(y, c0) - x) + broadcast(z, c1), broadcast(y + broadcast(z, fold(c1/c0)), c0) - x, c1 % c0 == 0) || + rewrite((broadcast(z, c1) - x) + broadcast(y, c0), broadcast(y + broadcast(z, fold(c1/c0)), c0) - x, c1 % c0 == 0) || rewrite(select(x, y, z) + select(x, w, u), select(x, y + w, z + u)) || rewrite(select(x, c0, c1) + c2, select(x, fold(c0 + c2), fold(c1 + c2))) || diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index db74645e8d6d..1dc6868f22f5 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -561,6 +561,10 @@ void check_vectors() { check(ramp(0, 1, 4) == broadcast(2, 4), ramp(-2, 1, 4) == broadcast(0, 4)); + check(ramp(broadcast(0, 6), broadcast(6, 6), 4) + broadcast(ramp(0, 1, 3), 8) + + broadcast(ramp(broadcast(0, 3), broadcast(3, 3), 2), 4), + ramp(0, 1, 24)); + // Any linear combination of simple ramps and broadcasts should // reduce to a single ramp or broadcast. std::mt19937 rng(0); From 2324c402619585ab1fe73cfc07f8199ebcc9b66b Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 8 Mar 2021 09:42:21 -0700 Subject: [PATCH 2/2] clang-format. --- test/correctness/simplify.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index 1dc6868f22f5..e42c87fc07e8 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -562,7 +562,7 @@ void check_vectors() { ramp(-2, 1, 4) == broadcast(0, 4)); check(ramp(broadcast(0, 6), broadcast(6, 6), 4) + broadcast(ramp(0, 1, 3), 8) + - broadcast(ramp(broadcast(0, 3), broadcast(3, 3), 2), 4), + broadcast(ramp(broadcast(0, 3), broadcast(3, 3), 2), 4), ramp(0, 1, 24)); // Any linear combination of simple ramps and broadcasts should