From fe3b65f1dc6da4a9be608dde0e24379dfb2691eb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Feb 2023 09:22:28 -0600 Subject: [PATCH] [Transform] Use callable() instead of isinstance() for type checking Previously, type-checking of a callable arguments, such as to `tvm.ir.transform.module_pass`, was done using `isinstance(arg, (types.FunctionType, types.LambdaType))`. This check can give false negatives for valid python types, such as a bound method or an instance of a class that implements `__call__`. This commit replaces the checks with the builtin function `callable()`, which handles any Python object that can be called using function-like syntax. --- python/tvm/ir/transform.py | 3 +-- python/tvm/relay/transform/transform.py | 2 +- python/tvm/te/hybrid/parser.py | 3 +-- python/tvm/tir/transform/function_pass.py | 3 +-- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 17995bfa7850..f7d40dc68147 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name,unused-argument """Common pass infrastructure across IR variants.""" -import types import inspect import functools @@ -340,7 +339,7 @@ def create_module_pass(pass_arg): info = PassInfo(opt_level, fname, required) if inspect.isclass(pass_arg): return _wrap_class_module_pass(pass_arg, info) - if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + if not callable(pass_arg): raise TypeError("pass_func must be a callable for Module pass") return _ffi_transform_api.MakeModulePass(pass_arg, info) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 1f5b91da4432..4c609620cbb7 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1049,7 +1049,7 @@ def create_function_pass(pass_arg): info = tvm.transform.PassInfo(opt_level, fname, required) if inspect.isclass(pass_arg): return _wrap_class_function_pass(pass_arg, info) - if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + if not callable(pass_arg): raise TypeError("pass_func must be a callable for Module pass") return _ffi_api.MakeFunctionPass(pass_arg, info) diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index ec103ac18811..bd47e416305f 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -20,7 +20,6 @@ import operator import logging import sys -import types import numbers from enum import Enum @@ -142,7 +141,7 @@ def __init__(self, args, usage, symbols, closure_vars, func_name=None): self.symbols = {} # Symbol table for k, v in symbols.items(): - if isinstance(v, types.FunctionType): + if callable(v): self.add_symbol(k, Symbol.Callable, v) self.closure_vars = closure_vars diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index 9450ade34e67..9fa0e3bc181f 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -16,7 +16,6 @@ # under the License. """TIR specific function pass support.""" import inspect -import types import functools from typing import Callable, List, Optional, Union @@ -151,7 +150,7 @@ def create_function_pass(pass_arg): info = PassInfo(opt_level, fname, required) if inspect.isclass(pass_arg): return _wrap_class_function_pass(pass_arg, info) - if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + if not callable(pass_arg): raise TypeError("pass_func must be a callable for Module pass") return _ffi_api.CreatePrimFuncPass(pass_arg, info) # type: ignore