diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 8fdba6650f25..db6c75642f3b 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -236,7 +236,7 @@ inline IntSet CombineInterval(Interval a, Interval b) { if (is_one(b.min)) return IntervalSet::make(a); Expr e1 = a.has_lower_bound() ? ComputeExpr(a.min, b.min) : a.min; Expr e2 = a.has_upper_bound() ? ComputeExpr(a.max, b.min) : a.max; - // This is relaxiation + // no relaxation is needed in here due to set is inclusive // TODO(tqchen): consider convert to StrideSet. if (is_positive_const(b.min)) { return IntervalSet::make(e1, e2); @@ -251,6 +251,32 @@ inline IntSet CombineInterval(Interval a, Interval b) { return IntSet::everything(); } +template<> +inline IntSet CombineInterval
(Interval a, Interval b) { + if (a.is_single_point() && b.is_single_point()) { + return IntSet::single_point(ComputeExpr
(a.min, b.min)); + } + if (b.is_single_point()) { + if (is_zero(b.min)) { + LOG(FATAL) << "Divide by zero in CombineInterval Div"; + } + if (is_one(b.min)) return IntervalSet::make(a); + Expr e1 = a.has_lower_bound() ? ComputeExpr
(a.min, b.min) : a.min; + Expr e2 = a.has_upper_bound() ? ComputeExpr
(a.max, b.min) : a.max; + // no relaxation is needed in here due to set is inclusive + if (is_positive_const(b.min)) { + return IntervalSet::make(e1, e2); + } else if (is_negative_const(b.min)) { + return IntervalSet::make(e2, e1); + } else if (a.is_bounded()) { + Expr cmp = b.min >= make_zero(b.min.type().element_of()); + return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1)); + } + } + LOG(WARNING) << "Return Everything in CombineInterval Div"; + return IntSet::everything(); +} + template<> inline IntSet CombineInterval(Interval a, Interval b) { if (a.is_single_point() && b.is_single_point()) {