diff --git a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py index cf087a4b4ac..2461a0e833a 100644 --- a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py +++ b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py @@ -53,13 +53,13 @@ def test_int64_model_tosa_FP(): class UpcastToInt64ForIndexCopyInplaceModel(torch.nn.Module): aten_op = "torch.ops.aten.index_copy_.default" - def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.tensor): + def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor): return x.index_copy_(0, index, y) def get_inputs(self) -> input_t3: return ( torch.zeros(5, 3), - torch.tensor([0, 4, 2]), + torch.LongTensor([0, 4, 2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float), ) @@ -85,13 +85,13 @@ def test_upcast_to_int64_for_index_copy_inplace_tosa_INT(): class UpcastToInt64ForIndexCopyModel(torch.nn.Module): aten_op = "torch.ops.aten.index_copy.default" - def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.tensor): + def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor): return x.index_copy(0, index, y) def get_inputs(self) -> input_t3: return ( torch.zeros(5, 3), - torch.tensor([0, 4, 2]), + torch.LongTensor([0, 4, 2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float), )