diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index a6a21ea9402a..22f815c3d812 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -862,6 +862,8 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d res : frame.AssertFrame The result AssertFrame. """ + if isinstance(condition, bool): + condition = IntImm("bool", condition) return _ffi_api.Assert(condition, message) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index b32b9b6c4584..1d1e674a9dd1 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -80,6 +80,9 @@ TVM_REGISTER_NODE_TYPE(AttrStmtNode); // AssertStmt AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span) { ICHECK(condition.defined()); + CHECK(condition.dtype().is_bool()) + << "AssertStmt should have boolean condition, " + << "but received " << condition << " with dtype " << condition.dtype(); ICHECK(message.dtype() == DataType::Int(32) || message.as()) << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index dd9d471c5066..825a8da45b27 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -42,6 +42,7 @@ namespace tir { static constexpr const char* kDeviceContextVar = "device_api_context"; +namespace { class ReturnRewriter : public StmtMutator { public: explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} @@ -176,6 +177,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 82685411f592..bdb3a953e99c 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -40,6 +40,8 @@ namespace tvm { namespace tir { +namespace { + class SubroutineCallRewriter : public StmtExprMutator { public: static Optional Apply(const std::unordered_set& external_methods, @@ -84,6 +86,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + PrimFunc MakeUnpackedAPI(PrimFunc func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index d78ba70f0919..e6334553d64f 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -277,13 +277,13 @@ def test_attr_stmt(): def test_assert_stmt(): with IRBuilder() as ib: - with T.Assert(1, "assertion"): + with T.Assert(True, "assertion"): T.evaluate(0) obj = ib.get() _assert_print( obj, """ -with T.Assert(1, "assertion"): +with T.Assert(T.bool(True), "assertion"): T.evaluate(0) """, ) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 1ff5be80cabc..41262a6669a3 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -452,5 +452,40 @@ def func(): assert var_name == "j" +def test_boolean_constant(): + """Python booleans should become T.Bool objects""" + + @T.prim_func + def explicit(): + T.evaluate(T.bool(True)) + + @T.prim_func + def implicit(): + T.evaluate(True) + + assert_structural_equal(implicit, explicit) + + +def test_foldable_boolean_in_assert(): + """Foldable booleans T.Bool objects + + The condition of an assert statement should be a boolean + expression. Previously, this test failed because the FFI does not + distinguish between integer primitives and boolean primitives. + """ + + @T.prim_func + def explicit(): + assert T.bool(False), "Message" + T.evaluate(0) + + @T.prim_func + def implicit(): + assert 0 == 1, "Message" + T.evaluate(0) + + assert_structural_equal(implicit, explicit) + + if __name__ == "__main__": tvm.testing.main()