From 0b3ddd2b040ec200ecc1e60cd233249e6a1d7f6c Mon Sep 17 00:00:00 2001
From: Yufeng Shi
Date: Wed, 5 Nov 2025 12:59:13 +0000
Subject: [PATCH] Arm backend: Upcast index argument to int64 for aten.index_copy ops

Change-Id: I01dbb2b69c5689dbc8b5b56534996eb7d74828ff
Signed-off-by: Yufeng Shi
---
 ...rt_int32_casts_after_int64_placeholders.py |  2 +
 ...t32_casts_after_int64_placeholders_pass.py | 76 ++++++++++++++++++-
 2 files changed, 77 insertions(+), 1 deletion(-)

diff --git a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py
index ef5aa9625c7..de80d61bfbe 100644
--- a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py
+++ b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py
@@ -36,6 +36,8 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass):
     # Key: op overload; Value: zero-based indices of positional args that must be i64.
     I64_INPUT_ARG_POSITIONS = {
         torch.ops.aten.one_hot.default: (0,),
+        torch.ops.aten.index_copy_.default: (2,),
+        torch.ops.aten.index_copy.default: (2,),
     }
 
     def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):
diff --git a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py
index 7c32cee8534..cf087a4b4ac 100644
--- a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py
+++ b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py
@@ -8,9 +8,13 @@
 
 import torch
 from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass
-from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
+from executorch.backends.arm.test.tester.test_pipeline import (
+    PassPipeline,
+    TosaPipelineINT,
+)
 
 input_t = Tuple[torch.Tensor, torch.Tensor]  # weights, indices
+input_t3 = Tuple[torch.Tensor, torch.LongTensor, torch.Tensor]
 
 
 class Int64InputModel(torch.nn.Module):
@@ -44,3 +48,73 @@ def test_int64_model_tosa_FP():
     )
     pipeline.pop_stage(-1)  # Do not compare output
     pipeline.run()
+
+
+class UpcastToInt64ForIndexCopyInplaceModel(torch.nn.Module):
+    """Module whose forward calls in-place ``index_copy_`` along dim 0."""
+
+    aten_op = "torch.ops.aten.index_copy_.default"
+
+    def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor):
+        return x.index_copy_(0, index, y)
+
+    def get_inputs(self) -> input_t3:
+        return (
+            torch.zeros(5, 3),
+            torch.tensor([0, 4, 2]),
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float),
+        )
+
+
+def test_upcast_to_int64_for_index_copy_inplace_tosa_INT():
+    """Run TosaPipelineINT on index_copy_ with the delegate count set to 0."""
+    module = UpcastToInt64ForIndexCopyInplaceModel()
+    pipeline = TosaPipelineINT[input_t3](
+        module,
+        module.get_inputs(),
+        aten_op=module.aten_op,
+    )
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.change_args(
+        "check_count.exir",
+        {
+            "torch.ops.higher_order.executorch_call_delegate": 0,
+        },
+    )
+    pipeline.pop_stage("run_method_and_compare_outputs")
+    pipeline.run()
+
+
+class UpcastToInt64ForIndexCopyModel(torch.nn.Module):
+    """Module whose forward calls out-of-place ``index_copy`` along dim 0."""
+
+    aten_op = "torch.ops.aten.index_copy.default"
+
+    def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor):
+        return x.index_copy(0, index, y)
+
+    def get_inputs(self) -> input_t3:
+        return (
+            torch.zeros(5, 3),
+            torch.tensor([0, 4, 2]),
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float),
+        )
+
+
+def test_upcast_to_int64_for_index_copy_tosa_INT():
+    """Run TosaPipelineINT on index_copy with the delegate count set to 0."""
+    module = UpcastToInt64ForIndexCopyModel()
+    pipeline = TosaPipelineINT[input_t3](
+        module,
+        module.get_inputs(),
+        aten_op=module.aten_op,
+    )
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.change_args(
+        "check_count.exir",
+        {
+            "torch.ops.higher_order.executorch_call_delegate": 0,
+        },
+    )
+    pipeline.pop_stage("run_method_and_compare_outputs")
+    pipeline.run()