From 210398c59b946404412bbd993f247d192c0afeed Mon Sep 17 00:00:00 2001 From: Rafael Stahl Date: Fri, 24 Jun 2022 15:59:22 +0200 Subject: [PATCH 1/5] [Relay][VirtualDevice] Expose WithFields to Python to do proper copy in ExprMutator --- python/tvm/relay/expr_functor.py | 13 +++++++++++-- python/tvm/relay/function.py | 7 +++++++ src/relay/ir/function.cc | 6 ++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index b9ca7d0e11f2..2093eb7b49a8 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -18,7 +18,7 @@ """The expression functor of Relay.""" from tvm.ir import Op -from .function import Function +from .function import Function, FunctionWithFields from .expr import Call, Let, Var, GlobalVar from .expr import If, Tuple, TupleGetItem, Constant from .expr import RefCreate, RefRead, RefWrite @@ -204,7 +204,16 @@ class ExprMutator(ExprFunctor): def visit_function(self, fn): new_params = [self.visit(x) for x in fn.params] new_body = self.visit(fn.body) - return Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs) + return FunctionWithFields( + fn, + list(new_params), + new_body, + fn.ret_type, + fn.type_params, + fn.attrs, + fn.virtual_device_, + fn.span, + ) def visit_let(self, let): new_var = self.visit(let.var) diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index f889f1e596ef..e4c13e75dd61 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -63,3 +63,10 @@ def __call__(self, *args): Arguments. """ return Call(self, args, None, None) + + +@tvm._ffi.register_func("relay.FunctionWithFields") +def FunctionWithFields(function, params, body, ret_type, ty_params, attrs, virtual_device, span): + return _ffi_api.FunctionWithFields( + function, params, body, ret_type, ty_params, attrs, virtual_device, span + ) diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 63e74144e061..858b377f06e4 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -127,6 +127,12 @@ TVM_REGISTER_GLOBAL("relay.ir.Function") tvm::Array ty_params, tvm::DictAttrs attrs) { return Function(params, body, ret_type, ty_params, attrs); }); +TVM_REGISTER_GLOBAL("relay.ir.FunctionWithFields") + .set_body_typed([](Function function, tvm::Array params, Expr body, Type ret_type, + tvm::Array ty_params, tvm::DictAttrs attrs, + VirtualDevice virtual_device, Span span) { + return WithFields(function, params, body, ret_type, ty_params, attrs, virtual_device, span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { From ad5dcd7e816d628ee0b0dcad4fc6db1565d9e596 Mon Sep 17 00:00:00 2001 From: Rafael Stahl Date: Mon, 27 Jun 2022 22:28:38 +0200 Subject: [PATCH 2/5] [Relay] give FunctionWithFields optional arguments --- python/tvm/relay/expr_functor.py | 5 ----- python/tvm/relay/function.py | 8 ++++---- src/relay/ir/function.cc | 10 ++++++---- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 2093eb7b49a8..ebea344b413a 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -208,11 +208,6 @@ def visit_function(self, fn): fn, list(new_params), new_body, - fn.ret_type, - fn.type_params, - fn.attrs, - fn.virtual_device_, - fn.span, ) def visit_let(self, let): diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index e4c13e75dd61..ac5b3db29de1 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -66,7 +66,7 @@ def __call__(self, *args): @tvm._ffi.register_func("relay.FunctionWithFields") -def FunctionWithFields(function, params, body, ret_type, ty_params, attrs, virtual_device, span): - return _ffi_api.FunctionWithFields( - function, params, body, ret_type, ty_params, attrs, virtual_device, span - ) +def FunctionWithFields( + function, params=None, body=None, ret_type=None, ty_params=None, attrs=None, virtual_device=None, span=None +): + return _ffi_api.FunctionWithFields(function, params, body, ret_type, ty_params, attrs, virtual_device, span) diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 858b377f06e4..1a3db9974f05 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -128,10 +128,12 @@ TVM_REGISTER_GLOBAL("relay.ir.Function") return Function(params, body, ret_type, ty_params, attrs); }); TVM_REGISTER_GLOBAL("relay.ir.FunctionWithFields") - .set_body_typed([](Function function, tvm::Array params, Expr body, Type ret_type, - tvm::Array ty_params, tvm::DictAttrs attrs, - VirtualDevice virtual_device, Span span) { - return WithFields(function, params, body, ret_type, ty_params, attrs, virtual_device, span); + .set_body_typed([](Function function, Optional> opt_params, Optional opt_body, + Optional opt_ret_type, Optional> opt_ty_params, + Optional opt_attrs, Optional opt_virtual_device, + Optional opt_span) { + return WithFields(function, opt_params, opt_body, opt_ret_type, opt_ty_params, opt_attrs, + opt_virtual_device, opt_span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) From 38ad5944ed879f2856ed99e72d703314ae4546d9 Mon Sep 17 00:00:00 2001 From: Rafael Stahl Date: Tue, 28 Jun 2022 11:24:35 +0200 Subject: [PATCH 3/5] [lint] fix wrong line length --- python/tvm/relay/function.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index ac5b3db29de1..6a047b5a91b2 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -67,6 +67,15 @@ def __call__(self, *args): @tvm._ffi.register_func("relay.FunctionWithFields") def FunctionWithFields( - function, params=None, body=None, ret_type=None, ty_params=None, attrs=None, virtual_device=None, span=None + function, + params=None, + body=None, + ret_type=None, + ty_params=None, + attrs=None, + virtual_device=None, + span=None, ): - return _ffi_api.FunctionWithFields(function, params, body, ret_type, ty_params, attrs, virtual_device, span) + return _ffi_api.FunctionWithFields( + function, params, body, ret_type, ty_params, attrs, virtual_device, span + ) \ No newline at end of file From be66322ef77508355c812374c8142d0f8a853ea0 Mon Sep 17 00:00:00 2001 From: Rafael Stahl Date: Tue, 28 Jun 2022 12:01:03 +0200 Subject: [PATCH 4/5] [lint] missing newline --- python/tvm/relay/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index 6a047b5a91b2..18d1fd240ba3 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -78,4 +78,4 @@ def FunctionWithFields( ): return _ffi_api.FunctionWithFields( function, params, body, ret_type, ty_params, attrs, virtual_device, span - ) \ No newline at end of file + ) From f3c8cc9249e258eaf5c2af742414ab520b50b2df Mon Sep 17 00:00:00 2001 From: Rafael Stahl Date: Tue, 28 Jun 2022 12:31:32 +0200 Subject: [PATCH 5/5] [doc] add doc string to FunctionWithFields --- python/tvm/relay/function.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index 18d1fd240ba3..6b3513cb5e1a 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -76,6 +76,11 @@ def FunctionWithFields( virtual_device=None, span=None, ): + """ + Returns function with the given properties. A None property denotes 'no change'. + Returns function if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ return _ffi_api.FunctionWithFields( function, params, body, ret_type, ty_params, attrs, virtual_device, span )