diff --git a/src/Simplify_Add.cpp b/src/Simplify_Add.cpp index b25c7c7e6a46..64ada056c357 100644 --- a/src/Simplify_Add.cpp +++ b/src/Simplify_Add.cpp @@ -51,12 +51,17 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) { (rewrite(x + x, x * 2) || rewrite(ramp(x, y, c0) + ramp(z, w, c0), ramp(x + z, y + w, c0)) || rewrite(ramp(x, y, c0) + broadcast(z, c0), ramp(x + z, y, c0)) || - rewrite(broadcast(x, c0) + broadcast(y, c0), broadcast(x + y, c0)) || rewrite(broadcast(x, c0) + broadcast(y, c1), broadcast(x + broadcast(y, fold(c1/c0)), c0), c1 % c0 == 0) || rewrite(broadcast(y, c1) + broadcast(x, c0), broadcast(x + broadcast(y, fold(c1/c0)), c0), c1 % c0 == 0) || - rewrite((x + broadcast(y, c0)) + broadcast(z, c0), x + broadcast(y + z, c0)) || - rewrite((x - broadcast(y, c0)) + broadcast(z, c0), x + broadcast(z - y, c0)) || + rewrite((x + broadcast(y, c0)) + broadcast(z, c1), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || + rewrite((x + broadcast(z, c1)) + broadcast(y, c0), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || + rewrite((broadcast(y, c0) + x) + broadcast(z, c1), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || + rewrite((broadcast(z, c1) + x) + broadcast(y, c0), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || + rewrite((x - broadcast(y, c0)) + broadcast(z, c1), x + broadcast(broadcast(z, fold(c1/c0)) - y, c0), c1 % c0 == 0) || + rewrite((x - broadcast(z, c1)) + broadcast(y, c0), x + broadcast(y - broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || + rewrite((broadcast(y, c0) - x) + broadcast(z, c1), broadcast(y + broadcast(z, fold(c1/c0)), c0) - x, c1 % c0 == 0) || + rewrite((broadcast(z, c1) - x) + broadcast(y, c0), broadcast(y + broadcast(z, fold(c1/c0)), c0) - x, c1 % c0 == 0) || rewrite(select(x, y, z) + select(x, w, u), select(x, y + w, z + u)) || rewrite(select(x, c0, c1) + c2, select(x, fold(c0 + c2), fold(c1 + c2))) || diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index db74645e8d6d..e42c87fc07e8 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -561,6 +561,10 @@ void check_vectors() { check(ramp(0, 1, 4) == broadcast(2, 4), ramp(-2, 1, 4) == broadcast(0, 4)); + check(ramp(broadcast(0, 6), broadcast(6, 6), 4) + broadcast(ramp(0, 1, 3), 8) + + broadcast(ramp(broadcast(0, 3), broadcast(3, 3), 2), 4), + ramp(0, 1, 24)); + // Any linear combination of simple ramps and broadcasts should // reduce to a single ramp or broadcast. std::mt19937 rng(0);