diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 493c3d957b28..8d7e81d7d0d8 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -349,14 +349,14 @@ def apply_prim_func_arg_and_result_memory_constraints( ) -def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool: +def verify_well_formed(obj: Union[PrimFunc, IRModule], assert_mode: bool = True) -> bool: """Verify if the given TIR is well-formed. The verification includes: - Check if expressions not contain vars that is defined outside the block. Parameters ---------- - func: tvm.tir.PrimFunc - The function to be verified. + obj: Union[tvm.tir.PrimFunc, tvm.ir.IRModule] + The function or module to be verified. assert_mode: bool The indicator if it raises an error when the function is not well-formed. @@ -366,7 +366,7 @@ def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool: result: bool Whether it is a well-formed TIR function. """ - return _ffi_api.VerifyWellFormed(func, assert_mode) # type: ignore # pylint: disable=no-member + return _ffi_api.VerifyWellFormed(obj, assert_mode) # type: ignore # pylint: disable=no-member def OOBChecker(): diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index e0318e14080f..898183533ccd 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -27,6 +27,7 @@ #include #include "../ir/functor_common.h" +#include "tvm/ir/module.h" namespace tvm { namespace tir { @@ -142,7 +143,29 @@ bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { return true; } -TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed").set_body_typed(VerifyWellFormed); +bool VerifyWellFormed(const IRModule& mod, bool assert_mode) { + for (const auto& [gvar, base_func] : mod->functions) { + if (auto prim_func = base_func.as()) { + bool res = VerifyWellFormed(prim_func.value(), assert_mode); + if (!res) { + return false; + } + } + } + return true; +} + +TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed") + .set_body_typed([](const ObjectRef& obj, bool assert_mode) { + if (auto opt = obj.as()) { + return VerifyWellFormed(opt.value(), assert_mode); + } else if (auto opt = obj.as()) { + return VerifyWellFormed(opt.value(), assert_mode); + } else { + LOG(FATAL) << "Expected VerifyWellFormed argument to be a PrimFunc or IRModule, but found " + << obj->GetTypeKey(); + } + }); } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_analysis_verify_well_formed.py b/tests/python/unittest/test_tir_analysis_verify_well_formed.py index 023d5f5f315c..4f88cc8be1e1 100644 --- a/tests/python/unittest/test_tir_analysis_verify_well_formed.py +++ b/tests/python/unittest/test_tir_analysis_verify_well_formed.py @@ -36,6 +36,7 @@ def element_wise( C[i, j] = B[i, j] * 2.0 assert tvm.tir.analysis.verify_well_formed(element_wise) + assert tvm.tir.analysis.verify_well_formed(tvm.IRModule.from_expr(element_wise)) def test_fail_use_out_loop_var():