Skip to content
2 changes: 1 addition & 1 deletion backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from . import arm_pass_utils # noqa
from .arm_pass import ArmPass # noqa # usort: skip
from .add_bias_pass import AddBiasPass # noqa
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
from .broadcast_args_pass import BroadcastArgsPass # noqa
from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa
Expand Down Expand Up @@ -85,6 +84,7 @@
)
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
from .size_adjust_input_pass import SizeAdjustInputPass # noqa
from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
from .replace_inf_values_pass import ReplaceInfValues # noqa # usort: skip
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import executorch.backends.arm.tosa.dialect # noqa: unused
from executorch.backends.arm._passes import (
AddBiasPass,
AnnotateChannelsLastDimOrder,
AnnotateDecomposedMatmulPass,
BroadcastArgsPass,
CastBoolToInt8Pass,
Expand Down Expand Up @@ -84,6 +83,7 @@
RetraceFoldedDtypesPass,
ScalarsToAttributePass,
SizeAdjustInputPass,
ToTosaMemoryFormatPass,
UnsqueezeBeforeRepeatPass,
UnsqueezeScalarPlaceholdersPass,
)
Expand Down Expand Up @@ -162,7 +162,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(AnnotateChannelsLastDimOrder())
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(InsertRescalePass())

return self._transform(exported_program.graph_module)
Expand Down Expand Up @@ -241,7 +241,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(AddBiasPass(exported_program))
self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(AnnotateChannelsLastDimOrder())
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(InsertRescalePass())

return self._transform(exported_program.graph_module)
Expand Down
12 changes: 8 additions & 4 deletions backends/arm/_passes/decompose_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
# pyre-unsafe

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

Expand All @@ -34,8 +37,9 @@ def call(self, graph_module: torch.fx.GraphModule):

input_node, dim, index = node.args

rank = len(input_node.meta["val"].size())
shape = input_node.meta["val"].shape
input_tensor = get_first_fake_tensor(input_node)
rank = len(input_tensor.size())
shape = input_tensor.shape
dim = dim % rank if dim < 0 else dim
index = index % shape[dim] if index < 0 else index

Expand All @@ -44,7 +48,7 @@ def call(self, graph_module: torch.fx.GraphModule):
graph_module.graph, slice_op, (input_node, dim, index, index + 1)
)
squeeze_node = create_node(
graph_module.graph, squeeze_op, (slice_node, [dim])
graph_module.graph, squeeze_op, (slice_node, [dim]), from_node=node
)

node.replace_all_uses_with(squeeze_node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,22 @@
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
is_param_node,
)
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class AnnotateChannelsLastDimOrder(ExportPass):
def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
    """
    Return True if ``node`` is a user input to the graph, i.e. a placeholder
    that is not a parameter of the exported program (parameters/constants are
    placeholders too, but are excluded here via ``is_param_node``).
    """
    return node.op == "placeholder" and not is_param_node(exported_program, node)


class ToTosaMemoryFormatPass(ExportPass):
"""
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE
Expand All @@ -30,6 +39,10 @@ class AnnotateChannelsLastDimOrder(ExportPass):
NNHWC_order = (0, 1, 3, 4, 2)
NNHWC_inverse_order = (0, 1, 4, 2, 3)

def __init__(self, exported_program: ExportedProgram) -> None:
    # Keep a handle to the exported program; the pass uses it (via _is_input)
    # to tell user inputs apart from parameter placeholders when deciding
    # which nodes get (N)NCHW <-> (N)NHWC transposes.
    self.exported_program = exported_program
    super().__init__()

def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
"""
returns True for w in the following sequence;
Expand Down Expand Up @@ -92,25 +105,30 @@ def is_channel_reshape(input_shape, output_shape):

@staticmethod
def insert_input_transpose(node, input_node, graph_module):
if input_node.target == exir_ops.backend.tosa.TRANSPOSE.default:
pre_permute_node = input_node.all_input_nodes[0]
node.replace_input_with(input_node, pre_permute_node)
return

with graph_module.graph.inserting_before(node):
permute_node = create_node(
graph_module.graph,
exir_ops.backend.tosa.TRANSPOSE.default,
args=(
input_node,
list(
AnnotateChannelsLastDimOrder.NNHWC_inverse_order
ToTosaMemoryFormatPass.NNHWC_inverse_order
if len(get_first_fake_tensor(input_node).size()) == 5
else AnnotateChannelsLastDimOrder.NHWC_inverse_order
else ToTosaMemoryFormatPass.NHWC_inverse_order
),
),
from_node=node,
)
node.replace_input_with(input_node, permute_node)

permute_node.meta["tosa_dim_order"] = tuple(
range(len(input_node.meta["val"].size()))
)
permute_node.meta["val"] = input_node.meta["val"]

@staticmethod
def insert_output_transpose(node, graph_module):
Expand All @@ -121,25 +139,23 @@ def insert_output_transpose(node, graph_module):
args=(
node,
list(
AnnotateChannelsLastDimOrder.NNHWC_order
ToTosaMemoryFormatPass.NNHWC_order
if len(get_first_fake_tensor(node).size()) == 5
else AnnotateChannelsLastDimOrder.NHWC_order
else ToTosaMemoryFormatPass.NHWC_order
),
),
from_node=node,
)

permute_node.meta["tosa_dim_order"] = (
AnnotateChannelsLastDimOrder.NNHWC_order
ToTosaMemoryFormatPass.NNHWC_order
if len(get_first_fake_tensor(node).size()) == 5
else AnnotateChannelsLastDimOrder.NHWC_order
)
permute_node.meta["val"] = get_first_fake_tensor(node).permute(
AnnotateChannelsLastDimOrder.NNHWC_order
if len(get_first_fake_tensor(node).size()) == 5
else AnnotateChannelsLastDimOrder.NHWC_order
else ToTosaMemoryFormatPass.NHWC_order
)
node.meta["tosa_dim_order"] = tuple(
range(len(get_first_fake_tensor(node).size()))
)

users = [user for user in node.users if user != permute_node]
for user in users:
user.replace_input_with(node, permute_node)
Expand All @@ -150,20 +166,23 @@ def _insert_view_transpose(
):
nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4
channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
channel_reshape = ToTosaMemoryFormatPass.is_channel_reshape(
output_shape, input_shape
)

if (
channel_reshape or nhwc_to_nchw
) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
AnnotateChannelsLastDimOrder.insert_input_transpose(
) and ToTosaMemoryFormatPass.memory_format_differs(input_shape):

ToTosaMemoryFormatPass.insert_input_transpose(
node, input_node, graph_module
)

if (
channel_reshape or nchw_to_nhwc
) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
) and ToTosaMemoryFormatPass.memory_format_differs(output_shape):

ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)

def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
"""
Expand All @@ -181,9 +200,10 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
# call_function and placeholder allowed due to
# index.Tensor being able to come in as both
if node.op not in ["call_function", "placeholder"]:
if node.op not in ["call_function", "placeholder", "output"]:
continue

# Transpose views
elif node.target in (
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.index.Tensor,
Expand All @@ -194,25 +214,48 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
input_node = node.args[0]
input_shape = input_node.meta["val"].shape
output_shape = node.meta["val"].shape

self._insert_view_transpose(
input_shape, output_shape, node, input_node, graph_module
input_shape,
output_shape,
node,
input_node,
graph_module,
)

# Transpose inputs
elif _is_input(node, self.exported_program):
input_shape = get_first_fake_tensor(node).size()
if len(input_shape) in (4, 5):
ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)

# Transpose outputs
elif node.op == "output":
output_shape = get_first_fake_tensor(node).size()

if len(output_shape) in (4, 5):
for input_node in node.all_input_nodes:
ToTosaMemoryFormatPass.insert_input_transpose(
node, input_node, graph_module
)

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
node_data = get_first_fake_tensor(node).data

if node_data.dim() == 4:
# Inputs and outputs are always in (N)NCHW format
if _is_input(node, self.exported_program) or node.op == "output":
dim_order = tuple(range(node_data.dim()))
elif node_data.dim() == 4:
dim_order = self.NHWC_order
if self.is_weight_node_for_depthwise_conv2d(node):
# The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
# dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
dim_order = self.HWCM_order
elif node_data.dim() == 5:
dim_order = self.NNHWC_order # type: ignore[assignment]
dim_order = self.NNHWC_order
else:
dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]

node.meta["tosa_dim_order"] = dim_order
# Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
# See insert_tosa_transposes for insertion conditions.
Expand Down
9 changes: 8 additions & 1 deletion backends/arm/operators/op_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,14 @@ def define_node(
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
[
ts.DType.INT8,
ts.DType.INT16,
ts.DType.INT32,
ts.DType.FP32,
ts.DType.BOOL,
ts.DType.FP16,
],
output.tosa_spec,
)

Expand Down
Loading
Loading