diff --git a/exir/delegate.py b/exir/delegate.py index de2be14a8e3..77049a8848b 100644 --- a/exir/delegate.py +++ b/exir/delegate.py @@ -123,44 +123,13 @@ def call_delegate_fake_tensor_mode(mode, lowered_module, *args): return lowered_module.original_module(*args) -@executorch_call_delegate.py_impl(torch._C.DispatchKey.Functionalize) +@executorch_call_delegate.py_functionalize_impl # pyre-ignore -def call_delegate_func(lowered_module, *args): - reapply_views = torch._C._functionalization_reapply_views_tls() - # At this point, we will see functionalized tensors, so need to unwrap them first - unwrapped_args = tuple( - _unwrap_all_tensors_from_functional(arg, reapply_views=reapply_views) - for arg in args - ) - guard = torch._C.ExcludeDispatchKeyGuard( - torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) - ) - try: - delegate_return = executorch_call_delegate(lowered_module, *unwrapped_args) - return _wrap_all_tensors_to_functional(delegate_return, level=0) - finally: - del guard - - -# pyre-ignore -@executorch_call_delegate.py_impl(torch._C._functorch.TransformType.Functionalize) -# pyre-ignore -def call_delegate_functionalize(interpreter, lowered_module, *args): - """ - Functionalization implementation for torch.ops.executorch_call_delegate. We - don't need to do anything since the delegated program is controlled by - users. - """ - reapply_views = interpreter.functionalize_add_back_views() - # At this point, we will see functionalized tensors, so need to unwrap them first - unwrapped_args = tuple( - _unwrap_all_tensors_from_functional(arg, reapply_views=reapply_views) - for arg in args - ) - - with interpreter.lower(): +def call_delegate_functionalize(ctx, lowered_module, *args): + unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args) + with ctx.redispatch_to_next(): res = executorch_call_delegate(lowered_module, *unwrapped_args) - return _wrap_all_tensors_to_functional(res, level=interpreter.level()) + return ctx.wrap_tensors(res) # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre