From 65ac46a66cd3f350db4dae04d0e8c93431266ed4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 May 2023 09:36:51 -0500 Subject: [PATCH 1/3] [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI PRs https://github.com/apache/tvm/pull/14913 and https://github.com/apache/tvm/pull/14914 made analogous changes to `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls. Both PRs introduced the same symbol, `tvm::tir::SubroutineCallRewriter`, a local utility to update internal calls to a modified function. While each PR passed CI individually, and was therefore able to merge, having both changes caused a duplicate symbol. This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place their local utilities into anonymous namespaces, avoiding the conflict. --- src/tir/transforms/make_packed_api.cc | 3 +++ src/tir/transforms/make_unpacked_api.cc | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index dd9d471c5066..825a8da45b27 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -42,6 +42,7 @@ namespace tir { static constexpr const char* kDeviceContextVar = "device_api_context"; +namespace { class ReturnRewriter : public StmtMutator { public: explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} @@ -176,6 +177,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 82685411f592..bdb3a953e99c 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -40,6 +40,8 @@ namespace tvm { namespace tir { +namespace { + class SubroutineCallRewriter : public StmtExprMutator { public: static Optional Apply(const std::unordered_set& external_methods, @@ -84,6 +86,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + PrimFunc MakeUnpackedAPI(PrimFunc func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. From b3feac6c22010c1235f77bc1f28ab669472fefeb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 24 May 2023 14:58:24 -0500 Subject: [PATCH 2/3] [TVMScript] Prevent bool to int conversion in T.Assert condition Previously, while literal `True` and `False` values were converted to `tvm::Bool` instances, constant-foldable expressions (e.g. `0 == 1`) would be evaluated to `True`, but were then passed directly to the FFI. Because the FFI uses the same representation for integer and boolean values, the conversion to `PrimExpr` resulted in a `tvm::Integer` instead of `tvm::Bool`. This commit converts the argument of `T.Assert` to a `tvm::Bool` before calling the FFI, avoiding the ambiguity. In addition, the `AssertStmt` constructor now validates the datatype of the condition, to prevent it from re-occurring. This was first caught in the unit test `test_debug_info.py::test_llvm_ir_debug_info`, which failed on some versions of LLVM due to the use of `i32` as the condition of an assert. --- python/tvm/script/ir_builder/tir/ir.py | 2 ++ src/tir/ir/stmt.cc | 3 ++ .../unittest/test_tvmscript_syntax_sugar.py | 35 +++++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index a6a21ea9402a..22f815c3d812 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -862,6 +862,8 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d res : frame.AssertFrame The result AssertFrame. """ + if isinstance(condition, bool): + condition = IntImm("bool", condition) return _ffi_api.Assert(condition, message) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index b32b9b6c4584..1d1e674a9dd1 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -80,6 +80,9 @@ TVM_REGISTER_NODE_TYPE(AttrStmtNode); // AssertStmt AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span) { ICHECK(condition.defined()); + CHECK(condition.dtype().is_bool()) + << "AssertStmt should have boolean condition, " + << "but received " << condition << " with dtype " << condition.dtype(); ICHECK(message.dtype() == DataType::Int(32) || message.as()) << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 1ff5be80cabc..41262a6669a3 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -452,5 +452,40 @@ def func(): assert var_name == "j" +def test_boolean_constant(): + """Python booleans should become T.Bool objects""" + + @T.prim_func + def explicit(): + T.evaluate(T.bool(True)) + + @T.prim_func + def implicit(): + T.evaluate(True) + + assert_structural_equal(implicit, explicit) + + +def test_foldable_boolean_in_assert(): + """Foldable booleans T.Bool objects + + The condition of an assert statement should be a boolean + expression. Previously, this test failed because the FFI does not + distinguish between integer primitives and boolean primitives. + """ + + @T.prim_func + def explicit(): + assert T.bool(False), "Message" + T.evaluate(0) + + @T.prim_func + def implicit(): + assert 0 == 1, "Message" + T.evaluate(0) + + assert_structural_equal(implicit, explicit) + + if __name__ == "__main__": tvm.testing.main() From edc95a773251eab5d5d0da8122ef980440dff820 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 May 2023 07:53:24 -0500 Subject: [PATCH 3/3] Updated TVMScript printer unit test with boolean condition --- tests/python/unittest/test_tvmscript_printer_tir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index d78ba70f0919..e6334553d64f 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -277,13 +277,13 @@ def test_attr_stmt(): def test_assert_stmt(): with IRBuilder() as ib: - with T.Assert(1, "assertion"): + with T.Assert(True, "assertion"): T.evaluate(0) obj = ib.get() _assert_print( obj, """ -with T.Assert(1, "assertion"): +with T.Assert(T.bool(True), "assertion"): T.evaluate(0) """, )