From fe3de55c11f3ce4369151987b53dbc555ef6d689 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Tue, 3 Oct 2023 06:01:37 -0700 Subject: [PATCH] internal fixes for FunctionalTensorMode usage in AOTAutograd (#538) Summary: Fixes needed to properly land https://github.com/pytorch/pytorch/pull/110079 internally (1) executorch has a higher order op that requires a functionalization rule (2) s-curve export still has an internal flow that calls some AOTAutograd API's, but also manually makes some calls to C++ funcitonalization. I changed them to use python functionalization. Reviewed By: zou3519, tugsbayasgalan Differential Revision: D49657241 --- exir/delegate.py | 41 +++++------------------------------------ 1 file changed, 5 insertions(+), 36 deletions(-) diff --git a/exir/delegate.py b/exir/delegate.py index de2be14a8e3..77049a8848b 100644 --- a/exir/delegate.py +++ b/exir/delegate.py @@ -123,44 +123,13 @@ def call_delegate_fake_tensor_mode(mode, lowered_module, *args): return lowered_module.original_module(*args) -@executorch_call_delegate.py_impl(torch._C.DispatchKey.Functionalize) +@executorch_call_delegate.py_functionalize_impl # pyre-ignore -def call_delegate_func(lowered_module, *args): - reapply_views = torch._C._functionalization_reapply_views_tls() - # At this point, we will see functionalized tensors, so need to unwrap them first - unwrapped_args = tuple( - _unwrap_all_tensors_from_functional(arg, reapply_views=reapply_views) - for arg in args - ) - guard = torch._C.ExcludeDispatchKeyGuard( - torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) - ) - try: - delegate_return = executorch_call_delegate(lowered_module, *unwrapped_args) - return _wrap_all_tensors_to_functional(delegate_return, level=0) - finally: - del guard - - -# pyre-ignore -@executorch_call_delegate.py_impl(torch._C._functorch.TransformType.Functionalize) -# pyre-ignore -def call_delegate_functionalize(interpreter, lowered_module, *args): - """ - Functionalization implementation for torch.ops.executorch_call_delegate. We - don't need to do anything since the delegated program is controlled by - users. - """ - reapply_views = interpreter.functionalize_add_back_views() - # At this point, we will see functionalized tensors, so need to unwrap them first - unwrapped_args = tuple( - _unwrap_all_tensors_from_functional(arg, reapply_views=reapply_views) - for arg in args - ) - - with interpreter.lower(): +def call_delegate_functionalize(ctx, lowered_module, *args): + unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args) + with ctx.redispatch_to_next(): res = executorch_call_delegate(lowered_module, *unwrapped_args) - return _wrap_all_tensors_to_functional(res, level=interpreter.level()) + return ctx.wrap_tensors(res) # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre