From d781eef57007615fb32e9043787aea6ee89e9df9 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 16 Mar 2023 02:41:47 +0000 Subject: [PATCH 1/2] [Fix][TIR] Fix tvm::arith::UnionLowerBound --- src/arith/int_set.cc | 1 + tests/python/unittest/test_arith_intset.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 7d601d9a8bae..a75d316a7ece 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -865,6 +865,7 @@ IntSet UnionLowerBound(const Array& sets) { PrimExpr min_inclusive{nullptr}; PrimExpr max_inclusive(nullptr); for (const IntSet& int_set : sets) { + if (int_set.IsNothing()) continue; if (const auto* interval_set = int_set.as()) { PrimExpr new_min_inclusive = interval_set->min_value; PrimExpr new_max_inclusive = interval_set->max_value; diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index da3fd94f8192..12214c596ce7 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -373,6 +373,10 @@ def test_union_lower_bound(): result = tvm.arith.int_set.union_lower_bound([set_0, set_1]) assert result.min_value.same_as(neg_inf) assert result.max_value.same_as(pos_inf) + set_2 = tvm.arith.IntervalSet(min_value=pos_inf, max_value=neg_inf) + result = tvm.arith.int_set.union_lower_bound([set_0, set_2]) + assert result.min_value.same_as(neg_inf) + assert result.max_value.same_as(0) if __name__ == "__main__": From cff6f3fc05518f0efc94e76bd678178d562a63e4 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 16 Mar 2023 08:31:00 +0000 Subject: [PATCH 2/2] fix CI test --- tests/python/unittest/test_arith_intset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 12214c596ce7..5b991151488e 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -374,9 +374,9 @@ def test_union_lower_bound(): assert result.min_value.same_as(neg_inf) assert result.max_value.same_as(pos_inf) set_2 = tvm.arith.IntervalSet(min_value=pos_inf, max_value=neg_inf) - result = tvm.arith.int_set.union_lower_bound([set_0, set_2]) + result = tvm.arith.int_set.union_lower_bound([set_0, set_1, set_2]) assert result.min_value.same_as(neg_inf) - assert result.max_value.same_as(0) + assert result.max_value.same_as(pos_inf) if __name__ == "__main__":