From fed5cc6a410bde30a0694358ff02c3d2bee7805a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 7 Jun 2022 08:20:27 -0500 Subject: [PATCH 1/6] [TIR] Simplify expressions using tir.ceil and tir.log2 These expressions are introduced in `topi.math.ceil_log2`, and can otherwise be propagated through to the generated kernel. --- src/arith/rewrite_simplify.cc | 14 ++++++++++ .../unittest/test_tir_transform_simplify.py | 26 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index a168e1f0836c..ca6b88f0134c 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1640,13 +1640,27 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { // the operator overload will eagerly constant fold. return op->args[0] << op->args[1]; } + } else if (op->op.same_as(Op::Get("tir.ceil"))) { + if (auto as_int = op->args[0].as()) { + return cast(op->dtype, IntImm(as_int->dtype, as_int->value)); + } else if (auto as_float = op->args[0].as()) { + return cast(op->dtype, FloatImm(as_float->dtype, std::ceil(as_float->value))); + } + } else if (op->op.same_as(Op::Get("tir.log2"))) { + if (auto as_int = op->args[0].as()) { + return cast(op->dtype, FloatImm(as_int->dtype, std::log2(as_int->value))); + } else if (auto as_float = op->args[0].as()) { + return cast(op->dtype, FloatImm(as_float->dtype, std::log2(as_float->value))); + } } + if (op->op.same_as(tir::builtin::likely())) { // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. 
} } if (auto match = TryMatchLiteralConstraint(op->args[0])) { return match.value(); } } + return ret; } diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 4f727cd89b12..7870c5c82c62 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -391,5 +391,31 @@ def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32): A[i, j] = 2 +class TestCeilLog2Float(BaseBeforeAfter): + """Simplify expressions resulting from topi.math.ceil_log2""" + + @T.prim_func + def before(A: T.Buffer[1, "float32"]): + A[0] = T.ceil(T.log2(14.0, dtype="float32"), dtype="float32") + + @T.prim_func + def expected(A: T.Buffer[1, "float32"]): + A[0] = 4.0 + + +class TestCeilLog2Int(BaseBeforeAfter): + """Simplify expressions resulting from topi.math.ceil_log2""" + + @T.prim_func + def before(A: T.Buffer[1, "int32"]): + A[0] = T.cast( + T.ceil(T.log2(T.cast(14, "float64"), dtype="float64"), dtype="float64"), dtype="int32" + ) + + @T.prim_func + def expected(A: T.Buffer[1, "int32"]): + A[0] = 4 + + if __name__ == "__main__": tvm.testing.main() From 2cc2661e3740ff90b747d5abf7301c1ac4ee62ea Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 9 Jun 2022 10:34:44 -0500 Subject: [PATCH 2/6] [Arith] Added left shift handling to ConstIntBoundsAnalyzer Previously, only right shift was handled. These left shifts are used in the `cuda.sort` implementation. 
--- src/arith/const_int_bound.cc | 36 ++++++++++++++++- .../unittest/test_tir_transform_simplify.py | 40 +++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 4fd27a0fde10..597da43f5423 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -314,6 +314,8 @@ class ConstIntBoundAnalyzer::Impl if (op->op.same_as(tir::builtin::shift_right())) { return VisitRightShift(op); + } else if (op->op.same_as(tir::builtin::shift_left())) { + return VisitLeftShift(op); } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); } else { @@ -341,6 +343,12 @@ class ConstIntBoundAnalyzer::Impl } } + Entry VisitLeftShift(const CallNode* op) { + Entry a = VisitExpr(op->args[0]); + Entry b = VisitExpr(op->args[1]); + return BinaryOpBoundary(a, b, InfAwareLeftShift); + } + Entry VisitRightShift(const CallNode* op) { Entry a = VisitExpr(op->args[0]); Entry b = VisitExpr(op->args[1]); @@ -509,7 +517,33 @@ class ConstIntBoundAnalyzer::Impl return floordiv(x, y); } /*! - * \brief Compute x / y, aware of inf. + * \brief Compute x << y, aware of inf. + * \param x The left operand. + * \param y The right operand. + * \return the result. + */ + static int64_t InfAwareLeftShift(int64_t x, int64_t y) { + if (x == kPosInf || x == kNegInf) return x; + + // Can be replaced with std::bit_width in C++20 + auto bit_width = [](int64_t as_signed) { + uint64_t val = std::abs(as_signed); + int num_bits = 0; + while (val) { + ++num_bits; + val >>= 1; + } + return num_bits; + }; + int x_bits = bit_width(x); + if (x_bits + y < 64) { + return x << y; + } else { + return kPosInf; + } + } + /*! + * \brief Compute x >> y, aware of inf. * \param x The left operand. * \param y The right operand. * \return the result. 
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 7870c5c82c62..b35b3766073a 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -417,5 +417,45 @@ def expected(A: T.Buffer[1, "int32"]): A[0] = 4 +class TestLeftShiftLowerBound(BaseBeforeAfter): + """Integer bounds are propagated through left shift + + min(1 << i) = 1 << min(i) + = 1 << 0 + = 1 + """ + + @T.prim_func + def before(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + if T.shift_left(1, i, dtype="int32") >= 1: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + A[i] = 0.0 + + +class TestLeftShiftUpperBound(BaseBeforeAfter): + """Integer bounds are propagated through left shift + + max(31 << i) = 31 << max(i) + = 31 << 15 + = 1015808 + """ + + @T.prim_func + def before(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + if T.shift_left(31, i, dtype="int32") <= 1015808: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + A[i] = 0.0 + + if __name__ == "__main__": tvm.testing.main() From 66687e6522bdc4602dffb06dccaa10738c50c253 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 9 Jun 2022 15:17:12 -0500 Subject: [PATCH 3/6] Update to avoid left shift of negative numbers --- src/arith/const_int_bound.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 597da43f5423..235b35d837aa 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -346,6 +346,12 @@ class ConstIntBoundAnalyzer::Impl Entry VisitLeftShift(const CallNode* op) { Entry a = VisitExpr(op->args[0]); Entry b = VisitExpr(op->args[1]); + + // Until C++20, performing a left shift is only well-defined for + // positive arguments. 
If we have a negative argument, it just + // means we couldn't prove that the inputs were positive. + a.min_value = std::max(int64_t(0), a.min_value); + b.min_value = std::max(int64_t(0), b.min_value); return BinaryOpBoundary(a, b, InfAwareLeftShift); } From 86d8165b01a89ce58b95ee05bdec31856b74e4fe Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 10 Jun 2022 10:49:26 -0500 Subject: [PATCH 4/6] Updated rewriting of log2(x) to only occur in ceil(log2(x)) Per @wrongtest's request, to avoid rounding differences between different devices. --- src/arith/rewrite_simplify.cc | 27 ++++++++++++------- .../unittest/test_tir_transform_simplify.py | 12 --------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index ca6b88f0134c..769e58698e09 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1641,16 +1641,23 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { return op->args[0] << op->args[1]; } } else if (op->op.same_as(Op::Get("tir.ceil"))) { - if (auto as_int = op->args[0].as()) { - return cast(op->dtype, IntImm(as_int->dtype, as_int->value)); - } else if (auto as_float = op->args[0].as()) { - return cast(op->dtype, FloatImm(as_float->dtype, std::ceil(as_float->value))); - } - } else if (op->op.same_as(Op::Get("tir.log2"))) { - if (auto as_int = op->args[0].as()) { - return cast(op->dtype, FloatImm(as_int->dtype, std::log2(as_int->value))); - } else if (auto as_float = op->args[0].as()) { - return cast(op->dtype, FloatImm(as_float->dtype, std::log2(as_float->value))); + PrimExpr ceil_arg = op->args[0]; + if (auto arg_int = op->args[0].as()) { + return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value)); + } else if (auto arg_float = ceil_arg.as()) { + return cast(op->dtype, FloatImm(arg_float->dtype, std::ceil(arg_float->value))); + } else if (auto arg_call = ceil_arg.as()) { + // ceil(log2(cast(n,"float64"))) is used as the 
implementation of + // topi.math.ceil_log2, and appears in iteration bounds. + if (arg_call->op.same_as(Op::Get("tir.log2"))) { + PrimExpr log_arg = arg_call->args[0]; + if (auto as_float = log_arg.as()) { + // ceil(log2(n)) can be simplified, and should produce the + // same integer result regardless of the target's rounding + // conventions. + return FloatImm(op->dtype, std::ceil(std::log2(as_float->value))); + } + } } } diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index b35b3766073a..09254bd6704e 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -391,18 +391,6 @@ def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32): A[i, j] = 2 -class TestCeilLog2Float(BaseBeforeAfter): - """Simplify expressions resulting from topi.math.ceil_log2""" - - @T.prim_func - def before(A: T.Buffer[1, "float32"]): - A[0] = T.ceil(T.log2(14.0, dtype="float32"), dtype="float32") - - @T.prim_func - def expected(A: T.Buffer[1, "float32"]): - A[0] = 4.0 - - class TestCeilLog2Int(BaseBeforeAfter): """Simplify expressions resulting from topi.math.ceil_log2""" From 1e4e642408cc92352855d358147ef87bc48becfd Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 21 Jun 2022 09:05:00 -0500 Subject: [PATCH 5/6] Avoid assumptions made of negative arguments to left-shift --- src/arith/const_int_bound.cc | 11 +++--- .../unittest/test_tir_transform_simplify.py | 34 +++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 235b35d837aa..716ce74b67ba 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -347,11 +347,12 @@ class ConstIntBoundAnalyzer::Impl Entry a = VisitExpr(op->args[0]); Entry b = VisitExpr(op->args[1]); - // Until C++20, performing a left shift is only well-defined for - // positive arguments. 
If we have a negative argument, it just - // means we couldn't prove that the inputs were positive. - a.min_value = std::max(int64_t(0), a.min_value); - b.min_value = std::max(int64_t(0), b.min_value); + if (a.min_value < 0 || b.min_value < 0) { + // If either operand can be negative, we may run into undefined + // behavior for some targets. In these cases, avoid making any + // assumptions about the result. + return Everything(op->dtype); + } return BinaryOpBoundary(a, b, InfAwareLeftShift); } diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 09254bd6704e..dba0cea6baa1 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -445,5 +445,39 @@ def expected(A: T.Buffer[16, "float32"]): A[i] = 0.0 +class TestLeftShiftOfNegativeValue(BaseBeforeAfter): + """No const int bounds of left shift of negative value. + + This is target dependent, and does not currently have a specified + behavior in TIR. For example, in CodeGenC, this generates C code + with undefined behavior. + """ + + @T.prim_func + def before(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + if -64 <= T.shift_left(-i, 4, dtype="int32"): + A[i] = 0.0 + + expected = before + + +class TestLeftShiftByNegativeValue(BaseBeforeAfter): + """No const int bounds of left shift by negative bit count. + + This is target dependent, and does not currently have a specified + behavior in TIR. For example, in CodeGenC, this generates C code + with undefined behavior. 
+ """ + + @T.prim_func + def before(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + if T.shift_left(16, -i, dtype="int32") <= 16: + A[i] = 0.0 + + expected = before + + if __name__ == "__main__": tvm.testing.main() From 58b9b7e98cba950a7206fe70e42fb0707a88971e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 21 Jun 2022 09:16:41 -0500 Subject: [PATCH 6/6] Recognize bounds of int(ceil(log2(arg))) --- src/arith/const_int_bound.cc | 53 ++++++++++++++++++- .../unittest/test_tir_transform_simplify.py | 19 +++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 716ce74b67ba..cabf299a886b 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -177,7 +177,17 @@ class ConstIntBoundAnalyzer::Impl } Entry VisitExpr_(const CastNode* op) final { - Entry a = VisitExpr(op->value); + Entry a; + + // int(ceil(log2(cast(n,"float64")))) is used as the + // implementation of topi.math.ceil_log2, and appears in iteration + // bounds. + if (auto opt = FindCeilLog2Arg(op)) { + a = CeilLog2Bounds(opt.value()); + } else { + a = VisitExpr(op->value); + } + Entry b = Everything(op->dtype); return Intersect(a, b); } @@ -353,6 +363,7 @@ class ConstIntBoundAnalyzer::Impl // assumptions about the result. return Everything(op->dtype); } + return BinaryOpBoundary(a, b, InfAwareLeftShift); } @@ -650,6 +661,46 @@ class ConstIntBoundAnalyzer::Impl } return {}; } + + /*! + * \brief Extract the argument from int(ceil(log2(arg))) + * + * This expression is used as the implementation of + * topi.math.ceil_log2, and can appear in iteration bounds. 
+ */ + static Optional FindCeilLog2Arg(const CastNode* op) { + if (op->dtype.is_int()) { + if (auto as_call = op->value.as()) { + if (as_call->op.same_as(Op::Get("tir.ceil"))) { + PrimExpr ceil_arg = as_call->args[0]; + if (auto arg_call = ceil_arg.as()) { + if (arg_call->op.same_as(Op::Get("tir.log2"))) { + PrimExpr log_arg = arg_call->args[0]; + return log_arg; + } + } + } + } + } + return NullOpt; + } + + /*! \brief Propagate constraints through ceil(log2(arg)) + * + * Helper function for CastNode visitor + */ + Entry CeilLog2Bounds(PrimExpr arg) { + if (auto as_float = arg.as()) { + // A cast from int to float may have already been simplified + // out. Normally we don't inspect floating-point arguments, but here we can + int64_t val = std::ceil(std::log2(as_float->value)); + return MakeBound(val, val); + } else { + Entry arg_bounds = VisitExpr(arg); + return MakeBound(std::ceil(std::log2(arg_bounds.min_value)), + std::ceil(std::log2(arg_bounds.max_value))); + } + } }; ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const { diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index dba0cea6baa1..49e8ee3f786d 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -405,6 +405,25 @@ def expected(A: T.Buffer[1, "int32"]): A[0] = 4 +class TestLeftCeilLog2LowerBound(BaseBeforeAfter): + """Integer bounds are propagated through topi.math.ceil_log2""" + + @T.prim_func + def before(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + x = T.cast( + T.ceil(T.log2(T.cast(i + 1024 + 1, "float64"), dtype="float64"), dtype="float64"), + dtype="int32", + ) + if x == 11: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + A[i] = 0.0 + + class TestLeftShiftLowerBound(BaseBeforeAfter): """Integer bounds are propagated through left shift