diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index d46e5826081..777b2a1c866 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -24,6 +24,7 @@ from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode from executorch.exir.error import InternalError from executorch.exir.operator.convert import ( + _get_overload_schema, get_out_args_from_opoverload, is_out_variant, to_out_variant, @@ -63,6 +64,7 @@ from torch._subclasses import FakeTensor from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.fx.passes.shape_prop import TensorMetadata +from torchgen.model import SchemaKind __all__ = [ "ExportPass", @@ -257,7 +259,6 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None: memory.alloc, memory.view, executorch_call_delegate, - torch.ops.aten.copy_.default, } to_out_var_skiplist.update(_EXECUTORCH_SYM_OPS) @@ -347,6 +348,8 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule: continue elif target in to_out_var_skiplist: continue + elif _get_overload_schema(target).kind() == SchemaKind.inplace: + continue if not isinstance( target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload) ):