diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 2ee427beb86c..d26ac3667620 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -173,8 +173,9 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { WithRecordIterPredicate(cond, [&] { true_value = this->VisitExpr(op->args[1]); }); } { - With constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond))); - false_value = this->VisitExpr(op->args[2]); + PrimExpr not_cond = Not(cond); + With constraint(analyzer_, not_cond); + WithRecordIterPredicate(not_cond, [&] { false_value = this->VisitExpr(op->args[2]); }); } if (is_zero(cond)) { return false_value; diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py b/tests/python/tir-transform/test_tir_transform_simplify.py index c779d92f9c47..6bad817c4955 100644 --- a/tests/python/tir-transform/test_tir_transform_simplify.py +++ b/tests/python/tir-transform/test_tir_transform_simplify.py @@ -1757,5 +1757,17 @@ def expected(a: T.handle): A[T.int64(1)] = T.float32(0) +class TestNestedIfElimination(BaseBeforeAfter): + def before(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")): + for i0, j0 in T.grid(2, 8): + b[i0, j0] = T.if_then_else( + i0 == 1 and 6 <= j0, 0, T.max(0, T.if_then_else(i0 == 1 and 6 <= j0, 0, a[i0, j0])) + ) + + def expected(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")): + for i0, j0 in T.grid(2, 8): + b[i0, j0] = T.if_then_else(i0 == 1 and 6 <= j0, 0, T.max(0, a[i0, j0])) + + if __name__ == "__main__": tvm.testing.main()