diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index b9ca7d0e11f2..ebea344b413a 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -18,7 +18,7 @@ """The expression functor of Relay.""" from tvm.ir import Op -from .function import Function +from .function import Function, FunctionWithFields from .expr import Call, Let, Var, GlobalVar from .expr import If, Tuple, TupleGetItem, Constant from .expr import RefCreate, RefRead, RefWrite @@ -204,7 +204,11 @@ class ExprMutator(ExprFunctor): def visit_function(self, fn): new_params = [self.visit(x) for x in fn.params] new_body = self.visit(fn.body) - return Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs) + return FunctionWithFields( + fn, + list(new_params), + new_body, + ) def visit_let(self, let): new_var = self.visit(let.var) diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index f889f1e596ef..6b3513cb5e1a 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -63,3 +63,24 @@ def __call__(self, *args): Arguments. """ return Call(self, args, None, None) + + +@tvm._ffi.register_func("relay.FunctionWithFields") +def FunctionWithFields( + function, + params=None, + body=None, + ret_type=None, + ty_params=None, + attrs=None, + virtual_device=None, + span=None, +): + """ + Returns function with the given properties. A None property denotes 'no change'. + Returns function if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.FunctionWithFields( + function, params, body, ret_type, ty_params, attrs, virtual_device, span + ) diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 63e74144e061..1a3db9974f05 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -127,6 +127,14 @@ TVM_REGISTER_GLOBAL("relay.ir.Function") tvm::Array ty_params, tvm::DictAttrs attrs) { return Function(params, body, ret_type, ty_params, attrs); }); +TVM_REGISTER_GLOBAL("relay.ir.FunctionWithFields") + .set_body_typed([](Function function, Optional> opt_params, Optional opt_body, + Optional opt_ret_type, Optional> opt_ty_params, + Optional opt_attrs, Optional opt_virtual_device, + Optional opt_span) { + return WithFields(function, opt_params, opt_body, opt_ret_type, opt_ty_params, opt_attrs, + opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) {