From 2c2d653f805a7b324e2e8b38cf59a2a16c11000c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 5 Jul 2023 15:04:20 -0500 Subject: [PATCH] [TIR] Allow VerifyWellFormed to accept IRModule Previously, the calling code needed to iterate over all functions in a module. This commit adds an overload that accepts `const IRModule&`, allowing it to be called more easily. This also provides an API that can be extended to validate behavior across an entire IRModule (e.g. requiring that internal function calls have the correct argument types). --- python/tvm/tir/analysis/analysis.py | 8 +++--- src/tir/analysis/verify_well_formed.cc | 25 ++++++++++++++++++- .../test_tir_analysis_verify_well_formed.py | 1 + 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 493c3d957b28..8d7e81d7d0d8 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -349,14 +349,14 @@ def apply_prim_func_arg_and_result_memory_constraints( ) -def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool: +def verify_well_formed(obj: Union[PrimFunc, IRModule], assert_mode: bool = True) -> bool: """Verify if the given TIR is well-formed. The verification includes: - Check if expressions not contain vars that is defined outside the block. Parameters ---------- - func: tvm.tir.PrimFunc - The function to be verified. + obj: Union[tvm.tir.PrimFunc, tvm.ir.IRModule] + The function or module to be verified. assert_mode: bool The indicator if it raises an error when the function is not well-formed. @@ -366,7 +366,7 @@ def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool: result: bool Whether it is a well-formed TIR function. """ - return _ffi_api.VerifyWellFormed(func, assert_mode) # type: ignore # pylint: disable=no-member + return _ffi_api.VerifyWellFormed(obj, assert_mode) # type: ignore # pylint: disable=no-member def OOBChecker(): diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index e0318e14080f..898183533ccd 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -27,6 +27,7 @@ #include #include "../ir/functor_common.h" +#include "tvm/ir/module.h" namespace tvm { namespace tir { @@ -142,7 +143,29 @@ bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { return true; } -TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed").set_body_typed(VerifyWellFormed); +bool VerifyWellFormed(const IRModule& mod, bool assert_mode) { + for (const auto& [gvar, base_func] : mod->functions) { + if (auto prim_func = base_func.as()) { + bool res = VerifyWellFormed(prim_func.value(), assert_mode); + if (!res) { + return false; + } + } + } + return true; +} + +TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed") + .set_body_typed([](const ObjectRef& obj, bool assert_mode) { + if (auto opt = obj.as()) { + return VerifyWellFormed(opt.value(), assert_mode); + } else if (auto opt = obj.as()) { + return VerifyWellFormed(opt.value(), assert_mode); + } else { + LOG(FATAL) << "Expected VerifyWellFormed argument to be a PrimFunc or IRModule, but found " + << obj->GetTypeKey(); + } + }); } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_analysis_verify_well_formed.py b/tests/python/unittest/test_tir_analysis_verify_well_formed.py index 023d5f5f315c..4f88cc8be1e1 100644 --- a/tests/python/unittest/test_tir_analysis_verify_well_formed.py +++ b/tests/python/unittest/test_tir_analysis_verify_well_formed.py @@ -36,6 +36,7 @@ def element_wise( C[i, j] = B[i, j] * 2.0 assert tvm.tir.analysis.verify_well_formed(element_wise) + assert tvm.tir.analysis.verify_well_formed(tvm.IRModule.from_expr(element_wise)) def test_fail_use_out_loop_var():