diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py
index 289c5ffeeec..a47422d3dc6 100644
--- a/backends/cadence/aot/remove_ops.py
+++ b/backends/cadence/aot/remove_ops.py
@@ -47,19 +47,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class RemoveDetachCopyPass(ExportPass):
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op != exir_ops.edge.aten.detach_copy.default:
-            return super().call_operator(op, args, kwargs, meta)
+class RemoveDetachCopyPass(RemoveOrReplacePassInterface):
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.detach_copy.default]
 
-        assert len(args) == 1
-        return cast(ProxyValue, args[0])
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        input_node = node.args[0]
+        assert isinstance(input_node, Node)
+        node.replace_all_uses_with(input_node)
+        return True
 
 
 # The following class consolidates passes to remove ops that are redundant:
@@ -114,53 +111,43 @@ def call_operator(
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class RemoveNopExpandOpPass(ExportPass):
+class RemoveNopExpandOpPass(RemoveOrReplacePassInterface):
     """
     For an expand op, if the operator shape matches the expand shape, then the
     expand is a nop.
     """
 
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if get_edge_overload_packet(op) not in {
-            exir_ops.edge.aten.expand_copy,
-            exir_ops.edge.aten.expand,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-
-        # Parse the args, and check for nop condition
-        arg0 = cast(ProxyValue, args[0])
-        arg1 = cast(Sequence[int], args[1])
-        in_tensor = arg0.to_tensor()
-        if list(in_tensor.shape) == list(arg1):
-            return arg0
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            exir_ops.edge.aten.expand_copy.default,
+            exir_ops.edge.aten.expand.default,
+        ]
 
-        return super().call_operator(op, args, kwargs, meta)
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        input_node = node.args[0]
+        assert isinstance(input_node, Node)
+        if input_node.meta["val"].shape == node.meta["val"].shape:
+            node.replace_all_uses_with(input_node)
+            return True
+        return False
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class RemoveToOpsPass(ExportPass):
+class RemoveToOpsPass(RemoveOrReplacePassInterface):
     # aten.to.* as of now are all nops
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in (
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             exir_ops.edge.aten.to.dtype,
             exir_ops.edge.aten.to.dtype_layout,
-        ):
-            return super().call_operator(op, args, kwargs, meta)
+        ]
 
-        logging.debug(f"Erasing to.dtype node (target = {op})")
-        return cast(ProxyValue, args[0])
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        input_node = node.args[0]
+        assert isinstance(input_node, Node)
+        node.replace_all_uses_with(input_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -210,40 +197,37 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class RemoveNopLinalgVectorNormOpPass(ExportPass):
+class RemoveNopLinalgVectorNormOpPass(RemoveOrReplacePassInterface):
     """
     If the norm is applied over a dimension that is size 1, it can be eliminated.
     """
 
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op is not exir_ops.edge.aten.linalg_vector_norm.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.linalg_vector_norm.default]
 
+    def maybe_remove_or_replace(self, node: Node) -> bool:
         # If the op has three args or less, it can't be a nop
-        if len(args) <= 3:
-            return super().call_operator(op, args, kwargs, meta)
+        if len(node.args) <= 3:
+            return False
 
         # If dim is None, or keepdim is False, it is not a nop
-        dim = cast(Optional[tuple[int, ...]], args[2])
-        keepdim = cast(bool, args[3])
+        dim = cast(Optional[tuple[int, ...]], node.args[2])
+        keepdim = cast(bool, node.args[3])
         if dim is None or not keepdim:
-            return super().call_operator(op, args, kwargs, meta)
+            return False
 
         # If the norm has 4 args and keepdim is True, check if dim is not None
         # and if the dimensions in dim are size 1. If not, the norm is not a nop.
-        t = cast(ProxyValue, args[0])
-        shape = t.to_tensor().shape
-        if len(args) < 4:
+        input_node = node.args[0]
+        assert isinstance(input_node, Node)
+        shape = input_node.meta["val"].shape
+        if len(node.args) >= 4:
             for d in dim:
                 if shape[d] != 1:
-                    return super().call_operator(op, args, kwargs, meta)
+                    return False
 
-        return t
+        node.replace_all_uses_with(input_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -358,23 +342,21 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class RemoveCloneOpPass(ExportPass):
+class RemoveCloneOpPass(RemoveOrReplacePassInterface):
     # If the op is a clone op, return the input and eliminate the op
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[ProxyValue],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op != exir_ops.edge.aten.clone.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.clone.default]
 
-        return args[0]
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        input_node = node.args[0]
+        assert isinstance(input_node, Node)
+        node.replace_all_uses_with(input_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class RemoveContiguousOpPass(ExportPass):
+class RemoveContiguousOpPass(RemoveOrReplacePassInterface):
     """
     This is based on the assumption that all tensors are contiguous in ExecuTorch
     and after cadence passes, and we should revisit this if that assumption is no
     longer true.
@@ -382,43 +364,37 @@ class RemoveContiguousOpPass(ExportPass):
     original graph module.
""" - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.contiguous.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.contiguous.default] - assert len(args) == 1 - return cast(ProxyValue, args[0]) + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveAliasCopyOpPass(ExportPass): +class RemoveAliasCopyOpPass(RemoveOrReplacePassInterface): """ alias_copy is a no-op and can be removed. """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.alias_copy.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.alias_copy.default] - assert len(args) == 1 - return cast(ProxyValue, args[0]) + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopRequantizeOpPass(ExportPass): +class RemoveNopRequantizeOpPass(RemoveOrReplacePassInterface): """ For a requantize op, if the following three conditions are satisfied: 1. the in_scale matches the out_scale @@ -427,100 +403,96 @@ class RemoveNopRequantizeOpPass(ExportPass): then the requantize op is redundant, and can be eliminated """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.cadence.requantize.per_tensor: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.cadence.requantize.per_tensor] - # Parse the args - (X, in_scale, in_zero_point, out_scale, out_zero_point, out_dtype) = cast( - tuple[ProxyValue, int, float, int, float, torch.dtype], args - ) - in_dtype = X.to_tensor().dtype + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + in_scale = node.args[1] + in_zero_point = node.args[2] + out_scale = node.args[3] + out_zero_point = node.args[4] + out_dtype = node.args[5] + in_dtype = input_node.meta["val"].dtype # Check the three conditions if ( in_scale == out_scale and in_zero_point == out_zero_point and in_dtype == out_dtype ): - return cast(ProxyValue, args[0]) - - return super().call_operator(op, args, kwargs, meta) + node.replace_all_uses_with(input_node) + return True + return False @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopMulOpPass(ExportPass): +class RemoveNopMulOpPass(RemoveOrReplacePassInterface): """ If a mul op is multiplying two tensors with the same shape and one of those tensors is all zeros, return the zero tensor instead. 
""" - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.mul.Tensor: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mul.Tensor] - # Parse the args - (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args) + def maybe_remove_or_replace(self, node: Node) -> bool: + input1 = node.args[0] + input2 = node.args[1] + assert isinstance(input1, Node) + assert isinstance(input2, Node) # Check if both inputs have the same shape - if input1.to_tensor().shape != input2.to_tensor().shape: - return super().call_operator(op, args, kwargs, meta) + if input1.meta["val"].shape != input2.meta["val"].shape: + return False # Check if one of the inputs is a zero tensor - if input1.node.target == exir_ops.edge.aten.full.default: - if input1.node.args[1] == 0: - return input1 - elif input2.node.target == exir_ops.edge.aten.full.default: - if input2.node.args[1] == 0: - return input2 + if input1.target == exir_ops.edge.aten.full.default: + if input1.args[1] == 0: + node.replace_all_uses_with(input1) + return True + elif input2.target == exir_ops.edge.aten.full.default: + if input2.args[1] == 0: + node.replace_all_uses_with(input2) + return True - return super().call_operator(op, args, kwargs, meta) + return False @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopAddOpPass(ExportPass): +class RemoveNopAddOpPass(RemoveOrReplacePassInterface): """ If an add op is adding two tensors with the same shape and one of those tensors is all zeros, return the other tensor instead. """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.add.Tensor: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.add.Tensor] - # Parse the args - (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args) + def maybe_remove_or_replace(self, node: Node) -> bool: + input1 = node.args[0] + input2 = node.args[1] + assert isinstance(input1, Node) + assert isinstance(input2, Node) # Check if both inputs have the same shape - if input1.to_tensor().shape != input2.to_tensor().shape: - return super().call_operator(op, args, kwargs, meta) + if input1.meta["val"].shape != input2.meta["val"].shape: + return False # Check if one of the inputs is a zero tensor - if input1.node.target == exir_ops.edge.aten.full.default: - if input1.node.args[1] == 0: - return input2 - elif input2.node.target == exir_ops.edge.aten.full.default: - if input2.node.args[1] == 0: - return input1 - - return super().call_operator(op, args, kwargs, meta) + if input1.target == exir_ops.edge.aten.full.default: + if input1.args[1] == 0: + node.replace_all_uses_with(input2) + return True + elif input2.target == exir_ops.edge.aten.full.default: + if input2.args[1] == 0: + node.replace_all_uses_with(input1) + return True + + return False @register_cadence_pass(CadencePassAttribute(opt_level=2))