From 2c0933038998a340b3bd12aae59d51ecbea7c966 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 13:10:52 +0100 Subject: [PATCH] Fixes #9030. Co-authored-by: Gemini 3.1 Pro --- src/Simplify_Exprs.cpp | 5 +++-- test/correctness/simplify.cpp | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index 0eb3bbaf3c15..bbd67a5bace0 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -69,7 +69,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { return value; } - if (info && op->type.is_int()) { + if (info && op->type.is_int_or_uint()) { switch (op->op) { case VectorReduce::Add: // Alignment of result is the alignment of the arg. Bounds @@ -123,7 +123,8 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { case VectorReduce::Add: { auto rewrite = IRMatcher::rewriter(IRMatcher::h_add(value, lanes), op->type); if (rewrite(h_add(x * broadcast(y, arg_lanes), lanes), h_add(x, lanes) * broadcast(y, lanes)) || - rewrite(h_add(broadcast(x, arg_lanes) * y, lanes), h_add(y, lanes) * broadcast(x, lanes))) { + rewrite(h_add(broadcast(x, arg_lanes) * y, lanes), h_add(y, lanes) * broadcast(x, lanes)) || + rewrite(h_add(broadcast(x, arg_lanes), lanes), broadcast(x * factor, lanes))) { return mutate(rewrite.result, info); } break; diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index 628de4d91504..de10bde5a1b9 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -810,6 +810,24 @@ void check_vectors() { int_vector); check(VectorReduce::make(VectorReduce::Max, Broadcast::make(int_vector, 4), 8), VectorReduce::make(VectorReduce::Max, Broadcast::make(int_vector, 4), 8)); + + { + // h_add(broadcast(x, 8), 4) should simplify to broadcast(x * 2, 4) + check(VectorReduce::make(VectorReduce::Add, broadcast(x, 8), 4), + broadcast(x * 2, 4)); + } + + { + Expr const_u8 = cast(UInt(8), 3); + check(VectorReduce::make(VectorReduce::Add, broadcast(const_u8, 9), 3), broadcast(cast(UInt(8), 9), 3)); + } + + { + // Test VectorReduce::Add on a variable of unsigned type to ensure the multiplied factor + // keeps the correct type and avoids type-mismatch assertion failures. + Expr u8_x = Variable::make(UInt(8), "u8_x"); + check(VectorReduce::make(VectorReduce::Add, broadcast(u8_x, 9), 3), broadcast(u8_x * cast(UInt(8), 3), 3)); + } } void check_bounds() {