Skip to content
9 changes: 6 additions & 3 deletions src/Simplify_Cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) {
return make_const(op->type, safe_numeric_cast<double>(*u), info);
} else if (cast &&
op->type.code() == cast->type.code() &&
op->type.bits() < cast->type.bits()) {
op->type.bits() < cast->type.bits() &&
op->type.lanes() == cast->value.type().lanes()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is wrong. It's impossible to construct a cast where this is true.

// If this is a cast of a cast of the same type, where the
// outer cast is narrower, the inner cast can be
// eliminated.
Expand All @@ -93,7 +94,8 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) {
cast->type.is_int() &&
cast->value.type().is_int() &&
op->type.bits() >= cast->type.bits() &&
cast->type.bits() >= cast->value.type().bits()) {
cast->type.bits() >= cast->value.type().bits() &&
op->type.lanes() == cast->value.type().lanes()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

// Casting from a signed type always sign-extends, so widening
// partway to a signed type and the rest of the way to some other
// integer type is the same as just widening to that integer type
Expand All @@ -103,7 +105,8 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) {
op->type.is_int_or_uint() &&
cast->type.is_int_or_uint() &&
op->type.bits() <= cast->type.bits() &&
op->type.bits() <= op->value.type().bits()) {
op->type.bits() <= op->value.type().bits() &&
op->type.lanes() == cast->value.type().lanes()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

// If this is a cast between integer types, where the
// outer cast is narrower than the inner cast and the
// inner cast's argument, the inner cast can be
Expand Down
36 changes: 18 additions & 18 deletions src/Simplify_Exprs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
rewrite(h_min(max(broadcast(x, arg_lanes), y), lanes), max(h_min(y, lanes), broadcast(x, lanes))) ||
rewrite(h_min(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
rewrite(h_min(broadcast(x, c0), lanes), h_min(x, lanes), factor % c0 == 0) ||
rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0)) ||
((lanes == 1) && rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0))) ||
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will cause merge conflict with #8629, we should just merge that first.

false) {
return mutate(rewrite.result, info);
}
Expand All @@ -151,7 +151,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
rewrite(h_max(max(broadcast(x, arg_lanes), y), lanes), max(h_max(y, lanes), broadcast(x, lanes))) ||
rewrite(h_max(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
rewrite(h_max(broadcast(x, c0), lanes), h_max(x, lanes), factor % c0 == 0) ||
rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0)) ||
((lanes == 1) && rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0))) ||
false) {
return mutate(rewrite.result, info);
}
Expand All @@ -165,14 +165,14 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
rewrite(h_and(broadcast(x, arg_lanes) && y, lanes), h_and(y, lanes) && broadcast(x, lanes)) ||
rewrite(h_and(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
rewrite(h_and(broadcast(x, c0), lanes), h_and(x, lanes), factor % c0 == 0) ||
rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
x + max(y * (arg_lanes - 1), 0) < z) ||
rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
x + max(y * (arg_lanes - 1), 0) <= z) ||
rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
x < y + min(z * (arg_lanes - 1), 0)) ||
rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
x <= y + min(z * (arg_lanes - 1), 0)) ||
((lanes == 1) && rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
x + max(y * (arg_lanes - 1), 0) < z)) ||
((lanes == 1) && rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
x + max(y * (arg_lanes - 1), 0) <= z)) ||
((lanes == 1) && rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
x < y + min(z * (arg_lanes - 1), 0))) ||
((lanes == 1) && rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
x <= y + min(z * (arg_lanes - 1), 0))) ||
false) {
return mutate(rewrite.result, info);
}
Expand All @@ -187,14 +187,14 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
rewrite(h_or(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
rewrite(h_or(broadcast(x, c0), lanes), h_or(x, lanes), factor % c0 == 0) ||
// type of arg_lanes is somewhat indeterminate
rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
x + min(y * (arg_lanes - 1), 0) < z) ||
rewrite(h_or(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
x + min(y * (arg_lanes - 1), 0) <= z) ||
rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
x < y + max(z * (arg_lanes - 1), 0)) ||
rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
x <= y + max(z * (arg_lanes - 1), 0)) ||
((lanes == 1) && rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
x + min(y * (arg_lanes - 1), 0) < z)) ||
((lanes == 1) && rewrite(h_or(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
x + min(y * (arg_lanes - 1), 0) <= z)) ||
((lanes == 1) && rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
x < y + max(z * (arg_lanes - 1), 0))) ||
((lanes == 1) && rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
x <= y + max(z * (arg_lanes - 1), 0))) ||
false) {
return mutate(rewrite.result, info);
}
Expand Down
4 changes: 2 additions & 2 deletions src/Simplify_Shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) {
}
}
if (can_collapse) {
return Ramp::make(r->base, r->stride, op->indices.size());
return mutate(Ramp::make(r->base, r->stride, op->indices.size()), info);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

r was already simplified, this just gives it more lanes. Was there a case here where that made further simplification possible? I.e. was this an idempotence failure?

}
}

Expand All @@ -257,7 +257,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) {
}

if (can_collapse) {
return Ramp::make(new_vectors[0], stride, op->indices.size());
return mutate(Ramp::make(new_vectors[0], stride, op->indices.size()), info);
}
}
}
Expand Down
107 changes: 105 additions & 2 deletions test/correctness/fuzz_simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ int get_random_divisor(RandomEngine &rng, Type t) {
return random_choice(rng, divisors);
}

// Choose a vector lane count from a fixed menu of small widths.
//
// A width is eligible only if it is at least `min_lanes` and is an exact
// multiple of `multiple_of`. It is an internal error for no width to
// qualify. The final pick among the eligible widths is delegated to
// random_choice.
int random_vector_width(RandomEngine &rng, int min_lanes = 2, int multiple_of = 1) {
    const int candidates[] = {2, 3, 4, 6, 8};

    std::vector<int> eligible;
    for (const int lanes : candidates) {
        // Test the minimum first so the modulo is only evaluated for
        // widths that already pass it.
        if (lanes < min_lanes || (lanes % multiple_of) != 0) {
            continue;
        }
        eligible.push_back(lanes);
    }

    internal_assert(!eligible.empty());
    return random_choice(rng, eligible);
}

Expr random_leaf(RandomEngine &rng, Type t, bool overflow_undef = false, bool imm_only = false) {
if (t.is_int() && t.bits() == 32) {
overflow_undef = true;
Expand Down Expand Up @@ -85,6 +96,87 @@ Expr random_leaf(RandomEngine &rng, Type t, bool overflow_undef = false, bool im

Expr random_expr(RandomEngine &rng, Type t, int depth, bool overflow_undef = false);

// Generate a random expression of type `t` built with one of the
// Shuffle::make_* constructors.
//
// Scalar `t`: extract one randomly chosen lane from a random vector whose
// width comes from random_vector_width.
//
// Vector `t`: pick (via random_choice) one of
//   - a concatenation of random sub-vectors,
//   - an interleave of random sub-vectors,
//   - a fixed permutation of a random vector of the same width,
//   - a stride-2 slice of a random vector twice as wide; only offered when
//     the doubled lane count does not exceed 8.
//
// `depth` and `overflow_undef` are forwarded unchanged to random_expr.
Expr random_shuffle_expr(RandomEngine &rng, Type t, int depth, bool overflow_undef) {
    if (t.is_scalar()) {
        const int src_lanes = random_vector_width(rng);
        Expr src = random_expr(rng, t.with_lanes(src_lanes), depth, overflow_undef);
        std::uniform_int_distribution<int> lane_picker(0, src_lanes - 1);
        const int lane = lane_picker(rng);
        return Shuffle::make_extract_element(src, lane);
    }

    // Shared by the concat and interleave cases: split t's lanes into a
    // number of equally sized pieces chosen by get_random_divisor and fill
    // each piece with a random expression.
    auto random_pieces = [&]() {
        const int count = get_random_divisor(rng, t);
        const Type piece_type = t.with_lanes(t.lanes() / count);
        std::vector<Expr> pieces;
        pieces.reserve(count);
        for (int remaining = count; remaining > 0; remaining--) {
            pieces.push_back(random_expr(rng, piece_type, depth, overflow_undef));
        }
        return pieces;
    };

    std::vector<std::function<Expr()>> generators = {
        [&]() {
            return Shuffle::make_concat(random_pieces());
        },
        [&]() {
            return Shuffle::make_interleave(random_pieces());
        },
        [&]() {
            Expr src = random_expr(rng, t, depth, overflow_undef);
            const int lane_count = t.lanes();
            std::vector<int> order(lane_count);
            for (int i = 0; i < lane_count; i++) {
                order[i] = i;
                if ((i & 1) == 0) {
                    continue;
                }
                // Odd positions trade places with the entry at half their
                // index, which yields a deterministic non-identity order.
                const int partner = i / 2;
                const int displaced = order[partner];
                order[partner] = order[i];
                order[i] = displaced;
            }
            return Shuffle::make({src}, order);
        },
    };

    const int doubled_lanes = t.lanes() * 2;
    if (doubled_lanes <= 8) {
        generators.push_back([&]() {
            Expr wide = random_expr(rng, t.with_lanes(doubled_lanes), depth, overflow_undef);
            // Start at lane 0 or 1, then take every second lane.
            std::uniform_int_distribution<int> start_picker(0, 1);
            const int start = start_picker(rng);
            return Shuffle::make_slice(wide, start, 2, t.lanes());
        });
    }

    const auto &chosen = random_choice(rng, generators);
    return chosen();
}

// Generate a random VectorReduce expression that produces `t.lanes()` lanes.
//
// The input vector is a random expression with the same element type as `t`.
// Its width comes from random_vector_width: any menu width for scalar `t`,
// otherwise a width that is at least `t.lanes()` and a multiple of it.
// Boolean inputs are reduced with And or Or only; every other input type
// picks from the full operator list below.
Expr random_vector_reduce_expr(RandomEngine &rng, Type t, int depth, bool overflow_undef) {
    const int output_lanes = t.lanes();

    int input_lanes;
    if (t.is_scalar()) {
        input_lanes = random_vector_width(rng);
    } else {
        input_lanes = random_vector_width(rng, output_lanes, output_lanes);
    }

    const Type input_type = t.with_lanes(input_lanes);
    Expr vec = random_expr(rng, input_type, depth, overflow_undef);

    // The operator is chosen after the operand is generated, so the random
    // engine is consumed in the same order as before.
    VectorReduce::Operator chosen_op;
    if (input_type.is_bool()) {
        VectorReduce::Operator bool_ops[] = {
            VectorReduce::And,
            VectorReduce::Or,
        };
        chosen_op = random_choice(rng, bool_ops);
    } else {
        VectorReduce::Operator general_ops[] = {
            VectorReduce::Add,
            VectorReduce::SaturatingAdd,
            VectorReduce::Mul,
            VectorReduce::Min,
            VectorReduce::Max,
            VectorReduce::And,
            VectorReduce::Or,
        };
        chosen_op = random_choice(rng, general_ops);
    }

    return VectorReduce::make(chosen_op, vec, output_lanes);
}

Expr random_condition(RandomEngine &rng, Type t, int depth, bool maybe_scalar) {
static make_bin_op_fn make_bin_op[] = {
EQ::make,
Expand All @@ -95,13 +187,18 @@ Expr random_condition(RandomEngine &rng, Type t, int depth, bool maybe_scalar) {
GE::make,
};

int lanes = t.lanes();
if (maybe_scalar && (rng() & 1)) {
t = t.element_of();
}

Expr a = random_expr(rng, t, depth);
Expr b = random_expr(rng, t, depth);
return random_choice(rng, make_bin_op)(a, b);
Expr result = random_choice(rng, make_bin_op)(a, b);
if (result.type().lanes() != lanes) {
result = Broadcast::make(result, lanes);
}
return result;
}

Expr make_absd(Expr a, Expr b) {
Expand Down Expand Up @@ -135,7 +232,7 @@ Expr make_bitwise_not(Expr a, Expr) {
}

Expr make_shift_right(Expr a, Expr b) {
return a >> (b % a.type().bits());
return a >> (b % make_const(b.type(), a.type().bits()));
}

Expr random_expr(RandomEngine &rng, Type t, int depth, bool overflow_undef) {
Expand Down Expand Up @@ -174,6 +271,12 @@ Expr random_expr(RandomEngine &rng, Type t, int depth, bool overflow_undef) {
}
return random_expr(rng, t, depth, overflow_undef);
},
[&]() {
return random_shuffle_expr(rng, t, depth, overflow_undef);
},
[&]() {
return random_vector_reduce_expr(rng, t, depth, overflow_undef);
},
[&]() {
if (t.is_bool()) {
auto e1 = random_expr(rng, t, depth);
Expand Down
6 changes: 6 additions & 0 deletions test/correctness/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,7 @@ void check_vectors() {
// Collapse some vector concats
check(concat_vectors({ramp(x, 2, 4), ramp(x + 8, 2, 4)}), ramp(x, 2, 8));
check(concat_vectors({ramp(x, 3, 2), ramp(x + 6, 3, 2), ramp(x + 12, 3, 2)}), ramp(x, 3, 6));
check(concat_vectors({x, x}), Broadcast::make(x, 2));

// Now some ones that can't work
{
Expand Down Expand Up @@ -828,6 +829,11 @@ void check_vectors() {
Expr u8_x = Variable::make(UInt(8), "u8_x");
check(VectorReduce::make(VectorReduce::Add, broadcast(u8_x, 9), 3), broadcast(u8_x * cast(UInt(8), 3), 3));
}

check(cast(UInt(32, 2),
VectorReduce::make(VectorReduce::Max, Ramp::make(cast(UInt(8), x), cast(UInt(8), y), 8), 2)),
cast(UInt(32, 2),
VectorReduce::make(VectorReduce::Max, Ramp::make(cast(UInt(8), x), cast(UInt(8), y), 8), 2)));
}

void check_bounds() {
Expand Down