Simplifier: keep vector lane counts consistent in cast, vector-reduce and
shuffle rewrites, and extend the simplifier fuzzer to cover shuffles and
vector reductions.

Summary of the hunks below:

  * src/Simplify_Cast.cpp: the three cast-of-cast rules that drop the inner
    cast now also require the outer cast's lane count to equal the lane count
    of the inner cast's argument.
  * src/Simplify_Exprs.cpp: the ramp-based h_min / h_max / h_and / h_or rules
    produce a scalar expression, so they are only attempted when the reduction
    has a single output lane (lanes == 1). In the h_and and h_or groups the
    fourth comparison rule now matches `broadcast(x) <= ramp(y, z)`; it used
    to repeat the `<` pattern of the third rule and therefore could never
    fire, even though its replacement was already written with `<=`.
  * src/Simplify_Shuffle.cpp: a concat that collapses into a Ramp is passed
    back through mutate() instead of being returned unsimplified.
  * test/correctness/fuzz_simplify.cpp: new random_vector_width,
    random_shuffle_expr and random_vector_reduce_expr generators;
    random_condition broadcasts a scalar comparison back to the requested
    lane count; make_shift_right builds its modulus with make_const.
  * test/correctness/simplify.cpp: two new checks.

NOTE(review): this patch was recovered from a copy in which line breaks,
indentation and template argument lists had been stripped. Indentation of
context lines and the `<double>` argument of safe_numeric_cast in the first
hunk are reconstructed; confirm them against the tree before applying.

diff --git a/src/Simplify_Cast.cpp b/src/Simplify_Cast.cpp
index 8738c26aeb97..d217f5968de4 100644
--- a/src/Simplify_Cast.cpp
+++ b/src/Simplify_Cast.cpp
@@ -83,7 +83,8 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) {
             return make_const(op->type, safe_numeric_cast<double>(*u), info);
         } else if (cast &&
                    op->type.code() == cast->type.code() &&
-                   op->type.bits() < cast->type.bits()) {
+                   op->type.bits() < cast->type.bits() &&
+                   op->type.lanes() == cast->value.type().lanes()) {
             // If this is a cast of a cast of the same type, where the
             // outer cast is narrower, the inner cast can be
             // eliminated.
@@ -93,7 +94,8 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) {
                    cast->type.is_int() &&
                    cast->value.type().is_int() &&
                    op->type.bits() >= cast->type.bits() &&
-                   cast->type.bits() >= cast->value.type().bits()) {
+                   cast->type.bits() >= cast->value.type().bits() &&
+                   op->type.lanes() == cast->value.type().lanes()) {
             // Casting from a signed type always sign-extends, so widening
             // partway to a signed type and the rest of the way to some other
             // integer type is the same as just widening to that integer type
@@ -103,7 +105,8 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) {
                    op->type.is_int_or_uint() &&
                    cast->type.is_int_or_uint() &&
                    op->type.bits() <= cast->type.bits() &&
-                   op->type.bits() <= op->value.type().bits()) {
+                   op->type.bits() <= op->value.type().bits() &&
+                   op->type.lanes() == cast->value.type().lanes()) {
             // If this is a cast between integer types, where the
             // outer cast is narrower than the inner cast and the
             // inner cast's argument, the inner cast can be
diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp
index bbd67a5bace0..0b167f8ab0a1 100644
--- a/src/Simplify_Exprs.cpp
+++ b/src/Simplify_Exprs.cpp
@@ -137,7 +137,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
             rewrite(h_min(max(broadcast(x, arg_lanes), y), lanes), max(h_min(y, lanes), broadcast(x, lanes))) ||
             rewrite(h_min(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
             rewrite(h_min(broadcast(x, c0), lanes), h_min(x, lanes), factor % c0 == 0) ||
-            rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0)) ||
+            ((lanes == 1) && rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0))) ||
             false) {
             return mutate(rewrite.result, info);
         }
@@ -151,7 +151,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
             rewrite(h_max(max(broadcast(x, arg_lanes), y), lanes), max(h_max(y, lanes), broadcast(x, lanes))) ||
             rewrite(h_max(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
             rewrite(h_max(broadcast(x, c0), lanes), h_max(x, lanes), factor % c0 == 0) ||
-            rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0)) ||
+            ((lanes == 1) && rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0))) ||
             false) {
             return mutate(rewrite.result, info);
         }
@@ -165,14 +165,14 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
             rewrite(h_and(broadcast(x, arg_lanes) && y, lanes), h_and(y, lanes) && broadcast(x, lanes)) ||
             rewrite(h_and(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
             rewrite(h_and(broadcast(x, c0), lanes), h_and(x, lanes), factor % c0 == 0) ||
-            rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
-                    x + max(y * (arg_lanes - 1), 0) < z) ||
-            rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
-                    x + max(y * (arg_lanes - 1), 0) <= z) ||
-            rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
-                    x < y + min(z * (arg_lanes - 1), 0)) ||
-            rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
-                    x <= y + min(z * (arg_lanes - 1), 0)) ||
+            ((lanes == 1) && rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
+                                     x + max(y * (arg_lanes - 1), 0) < z)) ||
+            ((lanes == 1) && rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
+                                     x + max(y * (arg_lanes - 1), 0) <= z)) ||
+            ((lanes == 1) && rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
+                                     x < y + min(z * (arg_lanes - 1), 0))) ||
+            ((lanes == 1) && rewrite(h_and(broadcast(x, arg_lanes) <= ramp(y, z, arg_lanes), lanes),
+                                     x <= y + min(z * (arg_lanes - 1), 0))) ||
             false) {
             return mutate(rewrite.result, info);
         }
@@ -187,14 +187,14 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
             rewrite(h_or(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
             rewrite(h_or(broadcast(x, c0), lanes), h_or(x, lanes), factor % c0 == 0) ||
             // type of arg_lanes is somewhat indeterminate
-            rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
-                    x + min(y * (arg_lanes - 1), 0) < z) ||
-            rewrite(h_or(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
-                    x + min(y * (arg_lanes - 1), 0) <= z) ||
-            rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
-                    x < y + max(z * (arg_lanes - 1), 0)) ||
-            rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
-                    x <= y + max(z * (arg_lanes - 1), 0)) ||
+            ((lanes == 1) && rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
+                                     x + min(y * (arg_lanes - 1), 0) < z)) ||
+            ((lanes == 1) && rewrite(h_or(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
+                                     x + min(y * (arg_lanes - 1), 0) <= z)) ||
+            ((lanes == 1) && rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
+                                     x < y + max(z * (arg_lanes - 1), 0))) ||
+            ((lanes == 1) && rewrite(h_or(broadcast(x, arg_lanes) <= ramp(y, z, arg_lanes), lanes),
+                                     x <= y + max(z * (arg_lanes - 1), 0))) ||
             false) {
             return mutate(rewrite.result, info);
         }
diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp
index aecb4c6fc99a..5ea02bd2170e 100644
--- a/src/Simplify_Shuffle.cpp
+++ b/src/Simplify_Shuffle.cpp
@@ -236,7 +236,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) {
                 }
             }
             if (can_collapse) {
-                return Ramp::make(r->base, r->stride, op->indices.size());
+                return mutate(Ramp::make(r->base, r->stride, op->indices.size()), info);
             }
         }
 
@@ -257,7 +257,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) {
             }
 
             if (can_collapse) {
-                return Ramp::make(new_vectors[0], stride, op->indices.size());
+                return mutate(Ramp::make(new_vectors[0], stride, op->indices.size()), info);
             }
         }
     }
diff --git a/test/correctness/fuzz_simplify.cpp b/test/correctness/fuzz_simplify.cpp
index 4154773dd1ec..235427eab330 100644
--- a/test/correctness/fuzz_simplify.cpp
+++ b/test/correctness/fuzz_simplify.cpp
@@ -53,6 +53,19 @@ int get_random_divisor(RandomEngine &rng, Type t) {
     return random_choice(rng, divisors);
 }
 
+// Pick a random vector width from {2, 3, 4, 6, 8} that is at least min_lanes
+// and a multiple of multiple_of. Asserts that at least one candidate exists.
+int random_vector_width(RandomEngine &rng, int min_lanes = 2, int multiple_of = 1) {
+    std::vector<int> widths;
+    for (int width : {2, 3, 4, 6, 8}) {
+        if (width >= min_lanes && (width % multiple_of) == 0) {
+            widths.push_back(width);
+        }
+    }
+    internal_assert(!widths.empty());
+    return random_choice(rng, widths);
+}
+
 Expr random_leaf(RandomEngine &rng, Type t, bool overflow_undef = false, bool imm_only = false) {
     if (t.is_int() && t.bits() == 32) {
         overflow_undef = true;
@@ -85,6 +98,91 @@ Expr random_leaf(RandomEngine &rng, Type t, bool overflow_undef = false, bool im
 
 Expr random_expr(RandomEngine &rng, Type t, int depth, bool overflow_undef = false);
 
+// Random Shuffle of type t: an element extract for scalars, otherwise a concat,
+// interleave, fixed permutation, or (if 2 * lanes <= 8) a stride-2 slice.
+Expr random_shuffle_expr(RandomEngine &rng, Type t, int depth, bool overflow_undef) {
+    if (t.is_scalar()) {
+        int lanes = random_vector_width(rng);
+        Expr vector = random_expr(rng, t.with_lanes(lanes), depth, overflow_undef);
+        std::uniform_int_distribution dist(0, lanes - 1);
+        return Shuffle::make_extract_element(vector, dist(rng));
+    }
+
+    std::vector<std::function<Expr()>> shuffles = {
+        [&]() {
+            int vectors = get_random_divisor(rng, t);
+            Type subtype = t.with_lanes(t.lanes() / vectors);
+            std::vector<Expr> exprs;
+            exprs.reserve(vectors);
+            for (int i = 0; i < vectors; i++) {
+                exprs.push_back(random_expr(rng, subtype, depth, overflow_undef));
+            }
+            return Shuffle::make_concat(exprs);
+        },
+        [&]() {
+            int vectors = get_random_divisor(rng, t);
+            Type subtype = t.with_lanes(t.lanes() / vectors);
+            std::vector<Expr> exprs;
+            exprs.reserve(vectors);
+            for (int i = 0; i < vectors; i++) {
+                exprs.push_back(random_expr(rng, subtype, depth, overflow_undef));
+            }
+            return Shuffle::make_interleave(exprs);
+        },
+        [&]() {
+            Expr vector = random_expr(rng, t, depth, overflow_undef);
+            std::vector<int> indices(t.lanes());
+            for (int i = 0; i < t.lanes(); i++) {
+                indices[i] = i;
+                if (i & 1) {
+                    int tmp = indices[i];
+                    indices[i] = indices[i / 2];
+                    indices[i / 2] = tmp;
+                }
+            }
+            return Shuffle::make({vector}, indices);
+        },
+    };
+
+    if (t.lanes() * 2 <= 8) {
+        shuffles.push_back([&]() {
+            Expr vector = random_expr(rng, t.with_lanes(t.lanes() * 2), depth, overflow_undef);
+            std::uniform_int_distribution dist(0, 1);
+            return Shuffle::make_slice(vector, dist(rng), 2, t.lanes());
+        });
+    }
+
+    return random_choice(rng, shuffles)();
+}
+
+// Random VectorReduce down to t.lanes() lanes from an input whose lane count is
+// a multiple of t.lanes(); bool inputs are restricted to And/Or.
+Expr random_vector_reduce_expr(RandomEngine &rng, Type t, int depth, bool overflow_undef) {
+    int input_lanes = t.is_scalar() ? random_vector_width(rng) : random_vector_width(rng, t.lanes(), t.lanes());
+    Type input_type = t.with_lanes(input_lanes);
+    Expr vec = random_expr(rng, input_type, depth, overflow_undef);
+    int output_lanes = t.lanes();
+
+    if (input_type.is_bool()) {
+        VectorReduce::Operator reduce_ops[] = {
+            VectorReduce::And,
+            VectorReduce::Or,
+        };
+        return VectorReduce::make(random_choice(rng, reduce_ops), vec, output_lanes);
+    }
+
+    VectorReduce::Operator reduce_ops[] = {
+        VectorReduce::Add,
+        VectorReduce::SaturatingAdd,
+        VectorReduce::Mul,
+        VectorReduce::Min,
+        VectorReduce::Max,
+        VectorReduce::And,
+        VectorReduce::Or,
+    };
+    return VectorReduce::make(random_choice(rng, reduce_ops), vec, output_lanes);
+}
+
 Expr random_condition(RandomEngine &rng, Type t, int depth, bool maybe_scalar) {
     static make_bin_op_fn make_bin_op[] = {
         EQ::make,
@@ -95,13 +193,18 @@ Expr random_condition(RandomEngine &rng, Type t, int depth, bool maybe_scalar) {
         GE::make,
     };
 
+    int lanes = t.lanes();
     if (maybe_scalar && (rng() & 1)) {
         t = t.element_of();
     }
 
     Expr a = random_expr(rng, t, depth);
     Expr b = random_expr(rng, t, depth);
-    return random_choice(rng, make_bin_op)(a, b);
+    Expr result = random_choice(rng, make_bin_op)(a, b);
+    if (result.type().lanes() != lanes) {
+        result = Broadcast::make(result, lanes);
+    }
+    return result;
 }
 
 Expr make_absd(Expr a, Expr b) {
@@ -135,7 +238,7 @@ Expr make_bitwise_not(Expr a, Expr) {
 }
 
 Expr make_shift_right(Expr a, Expr b) {
-    return a >> (b % a.type().bits());
+    return a >> (b % make_const(b.type(), a.type().bits()));
 }
 
 Expr random_expr(RandomEngine &rng, Type t, int depth, bool overflow_undef) {
@@ -174,6 +277,12 @@ Expr random_expr(RandomEngine &rng, Type t, int depth, bool overflow_undef) {
             }
             return random_expr(rng, t, depth, overflow_undef);
         },
+        [&]() {
+            return random_shuffle_expr(rng, t, depth, overflow_undef);
+        },
+        [&]() {
+            return random_vector_reduce_expr(rng, t, depth, overflow_undef);
+        },
         [&]() {
             if (t.is_bool()) {
                 auto e1 = random_expr(rng, t, depth);
diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp
index de10bde5a1b9..8616bfa024aa 100644
--- a/test/correctness/simplify.cpp
+++ b/test/correctness/simplify.cpp
@@ -672,6 +672,7 @@ void check_vectors() {
     // Collapse some vector concats
     check(concat_vectors({ramp(x, 2, 4), ramp(x + 8, 2, 4)}), ramp(x, 2, 8));
     check(concat_vectors({ramp(x, 3, 2), ramp(x + 6, 3, 2), ramp(x + 12, 3, 2)}), ramp(x, 3, 6));
+    check(concat_vectors({x, x}), Broadcast::make(x, 2));
 
     // Now some ones that can't work
     {
@@ -828,6 +829,11 @@ void check_vectors() {
         Expr u8_x = Variable::make(UInt(8), "u8_x");
         check(VectorReduce::make(VectorReduce::Add, broadcast(u8_x, 9), 3), broadcast(u8_x * cast(UInt(8), 3), 3));
     }
+
+    check(cast(UInt(32, 2),
+               VectorReduce::make(VectorReduce::Max, Ramp::make(cast(UInt(8), x), cast(UInt(8), y), 8), 2)),
+          cast(UInt(32, 2),
+               VectorReduce::make(VectorReduce::Max, Ramp::make(cast(UInt(8), x), cast(UInt(8), y), 8), 2)));
 }
 
 void check_bounds() {