diff --git a/backends/xnnpack/partition/graphs/bilinear_2d.py b/backends/xnnpack/partition/graphs/bilinear_2d.py index a971cb92441..0040439f845 100644 --- a/backends/xnnpack/partition/graphs/bilinear_2d.py +++ b/backends/xnnpack/partition/graphs/bilinear_2d.py @@ -37,12 +37,15 @@ def forward(self, x): ] for align_corners in [True, False]: for config in capture_configs: - edge = exir.capture( - bilinear2d(align_corners), sample_inputs, config - ).to_edge( - config=get_xnnpack_edge_compile_config(), - ) - _bilinear2d_graphs[edge.exported_program.graph_module] = align_corners + for skip_dim_order_flag in [True, False]: + edge = exir.capture( + bilinear2d(align_corners), sample_inputs, config + ).to_edge( + config=get_xnnpack_edge_compile_config( + skip_dim_order=skip_dim_order_flag + ) + ) + _bilinear2d_graphs[edge.exported_program.graph_module] = align_corners return _bilinear2d_graphs diff --git a/backends/xnnpack/passes/__init__.py b/backends/xnnpack/passes/__init__.py index 1ca4fe307fe..c3a85e4aa86 100644 --- a/backends/xnnpack/passes/__init__.py +++ b/backends/xnnpack/passes/__init__.py @@ -27,6 +27,7 @@ from executorch.exir.pass_base import ExportPass from executorch.exir.passes.const_prop_pass import ConstPropPass +from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass from executorch.exir.program._program import _transform from torch._export.pass_base import PassType @@ -50,6 +51,8 @@ def __init__( if not passes: # All the XNNPACK passes self.passes = [ + # TODO - remove this pass once we have a better support for dim_order ops lowering + DimOrderOpsRevertPass, ConvertToUpsampleBilinear2d, ConvertToLinearPass, ConvertToSDPAPass, diff --git a/backends/xnnpack/test/ops/bilinear2d.py b/backends/xnnpack/test/ops/bilinear2d.py index ab9d3d3c11d..d3c85350692 100644 --- a/backends/xnnpack/test/ops/bilinear2d.py +++ b/backends/xnnpack/test/ops/bilinear2d.py @@ -65,12 +65,15 @@ def forward(self, x): ) return a + # Since we may or may not enable 
dim order, use these ops only for + # check_not since we have `to_copy` and `to_dim_order_copy` in the list. ops = { "executorch_exir_dialects_edge__ops_aten_sub_Tensor", "executorch_exir_dialects_edge__ops_aten_mul_Tensor", "executorch_exir_dialects_edge__ops_aten_index_Tensor", "executorch_exir_dialects_edge__ops_aten_arange_start_step", "executorch_exir_dialects_edge__ops_aten__to_copy_default", + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default", "executorch_exir_dialects_edge__ops_aten_add_Tensor", "executorch_exir_dialects_edge__ops_aten_clamp_default", } @@ -81,7 +84,6 @@ def test_fp32_static_resize_bilinear2d(self): Tester(self.StaticResizeBilinear2dModule(), example_inputs) .export() .to_edge() - .check(self.ops) .partition() .check_not(self.ops) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -90,13 +92,12 @@ def test_fp32_static_resize_bilinear2d(self): .run_method_and_compare_outputs() ) - def test_fp32_static_resize_bilinear2d_with_align_cornesr(self): + def test_fp32_static_resize_bilinear2d_with_align_corners(self): example_inputs = (torch.randn(2, 3, 4, 5),) ( Tester(self.StaticResizeBilinear2dModuleWithAlignCorners(), example_inputs) .export() .to_edge() - .check(self.ops) .partition() .check_not(self.ops) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) diff --git a/backends/xnnpack/utils/configs.py b/backends/xnnpack/utils/configs.py index 3fe290606c2..9dda84c5e55 100644 --- a/backends/xnnpack/utils/configs.py +++ b/backends/xnnpack/utils/configs.py @@ -12,8 +12,12 @@ ### XNNPACK Configs ### -def get_xnnpack_edge_compile_config() -> exir.EdgeCompileConfig: - return exir.EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True) +def get_xnnpack_edge_compile_config( + skip_dim_order: bool = True, +) -> exir.EdgeCompileConfig: + return exir.EdgeCompileConfig( + _check_ir_validity=False, _skip_dim_order=skip_dim_order + ) def get_transform_passes(additional_passes=None) 
-> List[PassType]: diff --git a/exir/passes/dim_order_ops_registry.py b/exir/passes/dim_order_ops_registry.py index 7fed005b3c6..27fc03f9413 100644 --- a/exir/passes/dim_order_ops_registry.py +++ b/exir/passes/dim_order_ops_registry.py @@ -45,3 +45,15 @@ def _to_dim_order_copy_out_impl(*args, **kwargs): DimOrderOpsMap = { "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default, } + +""" +Defines a map of aten or edge ops to the corresponding memory format ops for quick lookup +""" +MemoryFormatOpsMap = { + "dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default, +} + +# If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts. +assert len(DimOrderOpsMap) == len(MemoryFormatOpsMap) + +# TODO stricter check for 1:1 mapping diff --git a/exir/passes/memory_format_ops_pass.py b/exir/passes/memory_format_ops_pass.py index 5a3c0f3a912..32678bf4082 100644 --- a/exir/passes/memory_format_ops_pass.py +++ b/exir/passes/memory_format_ops_pass.py @@ -9,13 +9,19 @@ import torch from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.dim_order_utils import get_dim_order +from executorch.exir.dim_order_utils import get_dim_order, get_memory_format from executorch.exir.pass_base import ExportPass, ProxyValue -from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap +from executorch.exir.passes.dim_order_ops_registry import ( + DimOrderOpsMap, + MemoryFormatOpsMap, +) logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) +# TODO - these passes are too specialized on a single to_copy op. +# We should be able to replace (or revert) any of the dim_order ops in the future. 
+ class MemoryFormatOpsPass(ExportPass): """ @@ -53,7 +59,55 @@ def call_operator(self, op, args, kwargs, meta): f" _to_dim_order_copy = dim_order: {nkwargs['dim_order']}" ) - t = DimOrderOpsMap[op.__name__] + t = DimOrderOpsMap.get(op.__name__, None) + assert t is not None, f"{op.__name__} not found in DimOrderOpsMap" + + return super().call_operator( + t, + args, + nkwargs, + meta, + ) + + +class DimOrderOpsRevertPass(ExportPass): + """ + This pass is to revert the dim_order ops back to the memory format ops. + """ + + def call_operator(self, op, args, kwargs, meta): + if not (isinstance(op, EdgeOpOverload) and op.__name__ in MemoryFormatOpsMap): + return super().call_operator( + op, + args, + kwargs, + meta, + ) + + # new kwargs with dim_order, and no memory_format for the new op + nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable + + # can always get the shape, assuming rank is specialized + if isinstance(args[0], ProxyValue) and args[0].is_tensor(): + ndim = args[0].to_tensor().dim() + elif isinstance(args[0], torch.Tensor): + ndim = args[0].dim() + else: + assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}" + + # get the "to" memory format for the EdgeOp + default_dim_order = list(range(ndim)) + dim_order = nkwargs.pop("dim_order", default_dim_order) + + nkwargs["memory_format"] = get_memory_format(dim_order) + + logger.debug( + f" _to_dim_order_copy = dim_order: {dim_order}." + f"_to_copy = rank: {ndim}, memory_format: {nkwargs['memory_format']}." 
+ ) + + t = MemoryFormatOpsMap.get(op.__name__, None) + assert t is not None, f"{op.__name__} not found in MemoryFormatOpsMap" return super().call_operator( t, diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 61d3af8afb6..99ec6481458 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -37,10 +37,11 @@ from executorch.exir.passes.insert_write_back_for_buffers_pass import ( insert_write_back_for_buffers_pass, ) + +from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass from executorch.exir.passes.normalize_view_copy_base_pass import ( NormalizeViewCopyBasePass, ) - from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass @@ -1676,3 +1677,100 @@ def forward(self, text_tokens): ) new_ep = constant_prop_pass(edge_manager._edge_programs["forward"]) _ = copy.deepcopy(new_ep.module_call_graph) + + def test_dim_order_revert_pass(self) -> None: + aten_op_str = "torch.ops.aten._to_copy.default" + edge_aten_op_str = "executorch_exir_dialects_edge__ops_aten__to_copy_default" + edge_dim_order_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" + + class Module(torch.nn.Module): + """ + A simple module that has a single to op that converts to channels last and then back to contiguous. + Assuming contiguous input. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.to(memory_format=torch.channels_last).to( + memory_format=torch.contiguous_format + ) + x.to(memory_format=torch.channels_last).to( + memory_format=torch.contiguous_format + ) + + @staticmethod + def to_copy_count(): + return 4 + + def _do_checks( + test_str: str, allowed: str, allowed_count: int, not_allowed_list: List[str] + ) -> None: + for not_allowed in not_allowed_list: + FileCheck().check_count(allowed, allowed_count, exactly=True).check_not( + not_allowed + ).run(test_str) + + m = Module() + n = m.to_copy_count() + input = torch.randn([2, 3, 4, 5]).to(memory_format=torch.contiguous_format) + + # 1. vanilla export, no edge ops + ep = export( + m, + (input,), + ) + _do_checks( + ep.graph_module.code, + aten_op_str, + n, + [edge_aten_op_str, edge_dim_order_op_str], + ) + + # 2a. to edge without dim orders, we should see edge aten ops but not dim order ops + edge_prog = to_edge( + ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=True) + )._edge_programs["forward"] + _do_checks( + edge_prog.graph_module.code, + edge_aten_op_str, + n, + [aten_op_str, edge_dim_order_op_str], + ) + + # 3a. expect no change after the pass, we should see edge aten ops but not dim order ops + new_res = DimOrderOpsRevertPass()(edge_prog.graph_module) + self.assertIsNotNone(new_res) + _do_checks( + new_res.graph_module.code, + edge_aten_op_str, + n, + [aten_op_str, edge_dim_order_op_str], + ) + + # 2b. let's try with dim order enabled, we should see edge dim order ops but not edge aten ops + edge_prog_dim_order = to_edge( + ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=False) + )._edge_programs["forward"] + _do_checks( + edge_prog_dim_order.graph_module.code, + edge_dim_order_op_str, + n, + [aten_op_str, edge_aten_op_str], + ) + + # 3b. 
expect edge aten ops after the pass, we should not see