From 0981612b3a370ddcc0666ca39280702f88487b7a Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 17 Jan 2025 12:12:15 +0100 Subject: [PATCH 01/18] Fix index_put with boolean index --- onnxscript/function_libs/torch_lib/ops/core.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a1793858e9..26ddbaf51a 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4297,7 +4297,12 @@ def aten_index_put_bool( """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor""" index = op.SequenceAt(indices, 0) # assume indices only have 1 element + # accumulate should be always False, True does not make sense. + return op.Where(index, values, self) + + """ # FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it + index = op.SequenceAt(indices, 0) # assume indices only have 1 element index_int = op.Cast(index, to=INT32.dtype) # if all False, return op.Identity(self) if op.ReduceSum(index_int) == 0: @@ -4327,6 +4332,7 @@ def aten_index_put_bool( result = op.ScatterElements(self, new_ind_t, values) return result + """ def aten_index_reduce( From e5ff42f6ed39f43e9345ed0ac7c3c88fb039f7a6 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 17 Jan 2025 12:18:02 +0100 Subject: [PATCH 02/18] add unit test --- tests/function_libs/torch_lib/e2e_test.py | 57 +++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/function_libs/torch_lib/e2e_test.py diff --git a/tests/function_libs/torch_lib/e2e_test.py b/tests/function_libs/torch_lib/e2e_test.py new file mode 100644 index 0000000000..aa1e75f22e --- /dev/null +++ b/tests/function_libs/torch_lib/e2e_test.py @@ -0,0 +1,57 @@ +import unittest +import onnx +import torch + + +class TestEnd2End(unittest.TestCase): + def test_adaptive_enc_mask(self): + + def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): + # first idx of each chunk, such as [0,18,36,48]. + chunk_start_idx = torch.Tensor(chunk_start_idx).long() + # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + start_pad = torch.nn.functional.pad(chunk_start_idx, (1, 0)) + # append x_len to the end, so it becomes [0,18,36,48, x_len] + end_pad = torch.nn.functional.pad(chunk_start_idx, (0, 1), value=x_len) + # seq_range size: [x_len, 1] + seq_range = torch.arange(0, x_len).unsqueeze(-1) + # idx size: [x_len] + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] + # boundary size: [x_len] + # boundary = end_pad[idx] + # seq_range_expand size [x_len, x_len] + seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + class MyModule(torch.nn.Module): + def forward(self, X): + x_len = 10 # 368 + chunk_start_idx = [4] + left_window = 18 + result = adaptive_enc_mask(x_len, chunk_start_idx, left_window, right_window=0) + return X + torch.unsqueeze(result, -1) + + torch_model = MyModule() + torch_model.eval() + inputs = (torch.randn(1, 1, 368),) + expected = torch_model(*inputs) + + program = torch.onnx.export(torch_model, inputs, dynamo=True) + # program.save(r"test_adaptive_enc_mask_not_optimized.onnx") + program.optimize() + program.save(r"test_adaptive_enc_mask.onnx") + ref = onnx.reference.ReferenceEvaluator(program.model_proto) + got = ref.run(None, {"x": inputs[0].numpy()}) + torch.testing.assert_close(expected, torch.tensor(got[0])) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From d52b331e43548319f1f19b7fd7e1b28885ce8115 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 17 Jan 2025 12:47:03 +0100 Subject: [PATCH 03/18] disable test for torch 2.6 --- .../function_libs/torch_lib/ops/core.py | 36 +------------------ tests/function_libs/torch_lib/e2e_test.py | 3 +- 2 files changed, 3 insertions(+), 36 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 26ddbaf51a..b85cefad97 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4297,43 +4297,9 @@ def aten_index_put_bool( """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor""" index = op.SequenceAt(indices, 0) # assume indices only have 1 element - # accumulate should be always False, True does not make sense. + # accumulate should be always False, True does not make sense but an assert would be great return op.Where(index, values, self) - """ - # FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it - index = op.SequenceAt(indices, 0) # assume indices only have 1 element - index_int = op.Cast(index, to=INT32.dtype) - # if all False, return op.Identity(self) - if op.ReduceSum(index_int) == 0: - result = self - else: - # change array([F,F,T,F,F]) to array([2]) - index = op.ArgMax(index_int) # assume index only have 1 True - # change array([2]) to array([2,2,2,2,2]) - self_dim_1 = op.Shape(self, start=1, end=2) - index_dim_0 = op.Shape(index, start=0, end=1) - shape = op.Concat(self_dim_1, index_dim_0, axis=0) - new_ind = op.Expand(index, shape) - new_ind_t = op.Transpose(new_ind) - - # values must have same rank with input(self) - if op.Size(op.Shape(values)) < op.Size(op.Shape(self)): # type: ignore[operator] - values = op.Unsqueeze(values, op.Constant(value_ints=[0])) - - if op.Cast(accumulate, to=BOOL.dtype): - zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self)) - zeros = op.CastLike(zeros, values) - result = op.ScatterElements(zeros, new_ind_t, values) - # FIXME: type promotion - result = op.CastLike(result, self) - result = op.Add(result, self) - else: - result = op.ScatterElements(self, new_ind_t, values) - - return result - """ - def aten_index_reduce( self: TensorType, diff --git a/tests/function_libs/torch_lib/e2e_test.py b/tests/function_libs/torch_lib/e2e_test.py index aa1e75f22e..a8d3e78c77 100644 --- a/tests/function_libs/torch_lib/e2e_test.py +++ b/tests/function_libs/torch_lib/e2e_test.py @@ -1,11 +1,12 @@ import unittest import onnx import torch +from onnxscript._internal.version_utils import torch_older_than class TestEnd2End(unittest.TestCase): + @unittest.skipIf(torch_older_than("2.6"), reason="fails to export") def test_adaptive_enc_mask(self): - def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): # first idx of each chunk, such as [0,18,36,48]. chunk_start_idx = torch.Tensor(chunk_start_idx).long() From 1e20ebe465e659a22e24447375521d2647eccabb Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 17 Jan 2025 12:59:18 +0100 Subject: [PATCH 04/18] lint --- tests/function_libs/torch_lib/e2e_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/function_libs/torch_lib/e2e_test.py b/tests/function_libs/torch_lib/e2e_test.py index a8d3e78c77..3270d4d4ef 100644 --- a/tests/function_libs/torch_lib/e2e_test.py +++ b/tests/function_libs/torch_lib/e2e_test.py @@ -1,6 +1,8 @@ import unittest + import onnx import torch + from onnxscript._internal.version_utils import torch_older_than From a84cb89feb0340a3c997261baaab0bbddd4e7477 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 17 Jan 2025 14:22:36 +0100 Subject: [PATCH 05/18] fix issues --- tests/function_libs/torch_lib/e2e_test.py | 24 +++++++++-------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/tests/function_libs/torch_lib/e2e_test.py b/tests/function_libs/torch_lib/e2e_test.py index 3270d4d4ef..4946450844 100644 --- a/tests/function_libs/torch_lib/e2e_test.py +++ b/tests/function_libs/torch_lib/e2e_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest import onnx @@ -9,27 +11,19 @@ class TestEnd2End(unittest.TestCase): @unittest.skipIf(torch_older_than("2.6"), reason="fails to export") def test_adaptive_enc_mask(self): - def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): - # first idx of each chunk, such as [0,18,36,48]. - chunk_start_idx = torch.Tensor(chunk_start_idx).long() - # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] - start_pad = torch.nn.functional.pad(chunk_start_idx, (1, 0)) - # append x_len to the end, so it becomes [0,18,36,48, x_len] - end_pad = torch.nn.functional.pad(chunk_start_idx, (0, 1), value=x_len) - # seq_range size: [x_len, 1] + def adaptive_enc_mask(x_len, start_idx, left_window=0, right_window=0): + start_idx = torch.Tensor(start_idx).long() + start_pad = torch.nn.functional.pad(start_idx, (1, 0)) + end_pad = torch.nn.functional.pad(start_idx, (0, 1), value=x_len) seq_range = torch.arange(0, x_len).unsqueeze(-1) - # idx size: [x_len] idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] - # boundary size: [x_len] - # boundary = end_pad[idx] - # seq_range_expand size [x_len, x_len] seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) idx_left = idx - left_window idx_left[idx_left < 0] = 0 boundary_left = start_pad[idx_left] mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) idx_right = idx + right_window - idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + idx_right[idx_right > len(start_idx)] = len(start_idx) boundary_right = end_pad[idx_right] mask_right = seq_range_expand < boundary_right.unsqueeze(-1) return mask_left & mask_right @@ -37,9 +31,9 @@ def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): class MyModule(torch.nn.Module): def forward(self, X): x_len = 10 # 368 - chunk_start_idx = [4] + start_idx = [4] left_window = 18 - result = adaptive_enc_mask(x_len, chunk_start_idx, left_window, right_window=0) + result = adaptive_enc_mask(x_len, start_idx, left_window, right_window=0) return X + torch.unsqueeze(result, -1) torch_model = MyModule() From 62ab7b8800a8668a081c04e7c9ab8e5c541493a1 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 17 Jan 2025 14:25:37 +0100 Subject: [PATCH 06/18] renaming --- tests/function_libs/torch_lib/e2e_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/function_libs/torch_lib/e2e_test.py b/tests/function_libs/torch_lib/e2e_test.py index 4946450844..edf86d3bb1 100644 --- a/tests/function_libs/torch_lib/e2e_test.py +++ b/tests/function_libs/torch_lib/e2e_test.py @@ -10,8 +10,8 @@ class TestEnd2End(unittest.TestCase): @unittest.skipIf(torch_older_than("2.6"), reason="fails to export") - def test_adaptive_enc_mask(self): - def adaptive_enc_mask(x_len, start_idx, left_window=0, right_window=0): + def test_index_put_failing_function(self): + def index_put_failing_function(x_len, start_idx, left_window=0, right_window=0): start_idx = torch.Tensor(start_idx).long() start_pad = torch.nn.functional.pad(start_idx, (1, 0)) end_pad = torch.nn.functional.pad(start_idx, (0, 1), value=x_len) @@ -33,7 +33,7 @@ def forward(self, X): x_len = 10 # 368 start_idx = [4] left_window = 18 - result = adaptive_enc_mask(x_len, start_idx, left_window, right_window=0) + result = index_put_failing_function(x_len, start_idx, left_window, right_window=0) return X + torch.unsqueeze(result, -1) torch_model = MyModule() @@ -42,9 +42,9 @@ def forward(self, X): expected = torch_model(*inputs) program = torch.onnx.export(torch_model, inputs, dynamo=True) - # program.save(r"test_adaptive_enc_mask_not_optimized.onnx") + # program.save(r"test_index_put_failing_function_not_optimized.onnx") program.optimize() - program.save(r"test_adaptive_enc_mask.onnx") + program.save(r"test_index_put_failing_function.onnx") ref = onnx.reference.ReferenceEvaluator(program.model_proto) got = ref.run(None, {"x": inputs[0].numpy()}) torch.testing.assert_close(expected, torch.tensor(got[0])) From 156b48a8de2aadad178f74c90081b94a769b9f46 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 17 Jan 2025 14:52:41 +0100 Subject: [PATCH 07/18] cusop --- tests/function_libs/torch_lib/e2e_test.py | 82 ++++++++++++++++++----- 1 file changed, 64 insertions(+), 18 deletions(-) diff --git a/tests/function_libs/torch_lib/e2e_test.py b/tests/function_libs/torch_lib/e2e_test.py index edf86d3bb1..ae8fcb6091 100644 --- a/tests/function_libs/torch_lib/e2e_test.py +++ b/tests/function_libs/torch_lib/e2e_test.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import unittest +from typing import List import onnx import torch @@ -8,32 +9,35 @@ from onnxscript._internal.version_utils import torch_older_than +def _index_put_failing_function(x_len, start_idx, left_window=0, right_window=0): + start_idx = torch.Tensor(start_idx).long() + start_pad = torch.nn.functional.pad(start_idx, (1, 0)) + end_pad = torch.nn.functional.pad(start_idx, (0, 1), value=x_len) + seq_range = torch.arange(0, x_len).unsqueeze(-1) + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] + seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(start_idx)] = len(start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + class TestEnd2End(unittest.TestCase): @unittest.skipIf(torch_older_than("2.6"), reason="fails to export") def test_index_put_failing_function(self): - def index_put_failing_function(x_len, start_idx, left_window=0, right_window=0): - start_idx = torch.Tensor(start_idx).long() - start_pad = torch.nn.functional.pad(start_idx, (1, 0)) - end_pad = torch.nn.functional.pad(start_idx, (0, 1), value=x_len) - seq_range = torch.arange(0, x_len).unsqueeze(-1) - idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] - seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) - idx_left = idx - left_window - idx_left[idx_left < 0] = 0 - boundary_left = start_pad[idx_left] - mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) - idx_right = idx + right_window - idx_right[idx_right > len(start_idx)] = len(start_idx) - boundary_right = end_pad[idx_right] - mask_right = seq_range_expand < boundary_right.unsqueeze(-1) - return mask_left & mask_right - class MyModule(torch.nn.Module): def forward(self, X): x_len = 10 # 368 start_idx = [4] left_window = 18 - result = index_put_failing_function(x_len, start_idx, left_window, right_window=0) + result = _index_put_failing_function( + x_len, start_idx, left_window, right_window=0 + ) return X + torch.unsqueeze(result, -1) torch_model = MyModule() @@ -49,6 +53,48 @@ def forward(self, X): got = ref.run(None, {"x": inputs[0].numpy()}) torch.testing.assert_close(expected, torch.tensor(got[0])) + def test_register_custom_op(self): + def index_put_failing_function( + device: torch.device, + x_len: int, + start_idx: List[int], + left_window: int, + right_window: int, + ) -> torch.Tensor: + return _index_put_failing_function( + x_len, start_idx, left_window, right_window + ) + + def index_put_failing_function_shape(device, x_len, start_idx, left_window, right_window): + return torch.empty((x_len, x_len), dtype=torch.bool).to(device) + + def register_custom_op(fct, fct_shape, namespace, fname): + schema_str = torch.library.infer_schema(fct, mutates_args=()) + custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct) + custom_def.register_kernel("cpu")(fct) + custom_def._abstract_fn = fct_shape + + register_custom_op( + index_put_failing_function, + index_put_failing_function_shape, + "test_delayed", + "index_put_failing_function", + ) + + class MyModule(torch.nn.Module): + def forward(self, X): + x_len = 10 # 368 + start_idx = [4] + left_window = 18 + result = torch.ops.test_delayed.index_put_failing_function( + "cpu", x_len, start_idx, left_window, 0 + ) + return X + torch.unsqueeze(result, -1) + + inputs = (torch.randn(1, 1, 368),) + ep = torch.export.export(MyModule(), args=inputs, strict=False) + self.assertIn("torch.ops.test_delayed.index_put_failing_function.default", str(ep)) + if __name__ == "__main__": unittest.main(verbosity=2) From 86f51ee3a2419c213012aa5334e0dfa6cdd178b9 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 17 Jan 2025 15:36:29 +0100 Subject: [PATCH 08/18] fix one test --- tests/function_libs/torch_lib/e2e_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/function_libs/torch_lib/e2e_test.py b/tests/function_libs/torch_lib/e2e_test.py index ae8fcb6091..e5ec66ab6c 100644 --- a/tests/function_libs/torch_lib/e2e_test.py +++ b/tests/function_libs/torch_lib/e2e_test.py @@ -53,6 +53,7 @@ def forward(self, X): got = ref.run(None, {"x": inputs[0].numpy()}) torch.testing.assert_close(expected, torch.tensor(got[0])) + @unittest.skipIf(torch_older_than("2.6"), reason="no infer_schema") def test_register_custom_op(self): def index_put_failing_function( device: torch.device, From b265c12a55edb782d917b97404290acd75e5f3e8 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 17 Jan 2025 16:13:59 +0100 Subject: [PATCH 09/18] lint --- tests/function_libs/torch_lib/e2e_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/e2e_test.py b/tests/function_libs/torch_lib/e2e_test.py index e5ec66ab6c..d75c146c2c 100644 --- a/tests/function_libs/torch_lib/e2e_test.py +++ b/tests/function_libs/torch_lib/e2e_test.py @@ -62,18 +62,23 @@ def index_put_failing_function( left_window: int, right_window: int, ) -> torch.Tensor: + del device return _index_put_failing_function( x_len, start_idx, left_window, right_window ) def index_put_failing_function_shape(device, x_len, start_idx, left_window, right_window): + del device + del start_idx + del left_window + del right_window return torch.empty((x_len, x_len), dtype=torch.bool).to(device) def register_custom_op(fct, fct_shape, namespace, fname): schema_str = torch.library.infer_schema(fct, mutates_args=()) custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct) custom_def.register_kernel("cpu")(fct) - custom_def._abstract_fn = fct_shape + custom_def._abstract_fn = fct_shape # pylint: disable=protected-access register_custom_op( index_put_failing_function, From dbdac0c638e80f19ede7e25e231d30cce24ccb10 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 17 Jan 2025 17:48:57 +0100 Subject: [PATCH 10/18] lint --- tests/function_libs/torch_lib/e2e_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/function_libs/torch_lib/e2e_test.py b/tests/function_libs/torch_lib/e2e_test.py index d75c146c2c..f2a0dad8e5 100644 --- a/tests/function_libs/torch_lib/e2e_test.py +++ b/tests/function_libs/torch_lib/e2e_test.py @@ -63,11 +63,11 @@ def index_put_failing_function( right_window: int, ) -> torch.Tensor: del device - return _index_put_failing_function( - x_len, start_idx, left_window, right_window - ) + return _index_put_failing_function(x_len, start_idx, left_window, right_window) - def index_put_failing_function_shape(device, x_len, start_idx, left_window, right_window): + def index_put_failing_function_shape( + device, x_len, start_idx, left_window, right_window + ): del device del start_idx del left_window From 035b8aade0707a3c3f5ea26e98295e181517993b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 22 Jan 2025 10:12:51 -0800 Subject: [PATCH 11/18] Delete tests/function_libs/torch_lib/e2e_test.py --- tests/function_libs/torch_lib/e2e_test.py | 106 ---------------------- 1 file changed, 106 deletions(-) delete mode 100644 tests/function_libs/torch_lib/e2e_test.py diff --git a/tests/function_libs/torch_lib/e2e_test.py b/tests/function_libs/torch_lib/e2e_test.py deleted file mode 100644 index f2a0dad8e5..0000000000 --- a/tests/function_libs/torch_lib/e2e_test.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import unittest -from typing import List - -import onnx -import torch - -from onnxscript._internal.version_utils import torch_older_than - - -def _index_put_failing_function(x_len, start_idx, left_window=0, right_window=0): - start_idx = torch.Tensor(start_idx).long() - start_pad = torch.nn.functional.pad(start_idx, (1, 0)) - end_pad = torch.nn.functional.pad(start_idx, (0, 1), value=x_len) - seq_range = torch.arange(0, x_len).unsqueeze(-1) - idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] - seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) - idx_left = idx - left_window - idx_left[idx_left < 0] = 0 - boundary_left = start_pad[idx_left] - mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) - idx_right = idx + right_window - idx_right[idx_right > len(start_idx)] = len(start_idx) - boundary_right = end_pad[idx_right] - mask_right = seq_range_expand < boundary_right.unsqueeze(-1) - return mask_left & mask_right - - -class TestEnd2End(unittest.TestCase): - @unittest.skipIf(torch_older_than("2.6"), reason="fails to export") - def test_index_put_failing_function(self): - class MyModule(torch.nn.Module): - def forward(self, X): - x_len = 10 # 368 - start_idx = [4] - left_window = 18 - result = _index_put_failing_function( - x_len, start_idx, left_window, right_window=0 - ) - return X + torch.unsqueeze(result, -1) - - torch_model = MyModule() - torch_model.eval() - inputs = (torch.randn(1, 1, 368),) - expected = torch_model(*inputs) - - program = torch.onnx.export(torch_model, inputs, dynamo=True) - # program.save(r"test_index_put_failing_function_not_optimized.onnx") - program.optimize() - program.save(r"test_index_put_failing_function.onnx") - ref = onnx.reference.ReferenceEvaluator(program.model_proto) - got = ref.run(None, {"x": inputs[0].numpy()}) - torch.testing.assert_close(expected, torch.tensor(got[0])) - - @unittest.skipIf(torch_older_than("2.6"), reason="no infer_schema") - def test_register_custom_op(self): - def index_put_failing_function( - device: torch.device, - x_len: int, - start_idx: List[int], - left_window: int, - right_window: int, - ) -> torch.Tensor: - del device - return _index_put_failing_function(x_len, start_idx, left_window, right_window) - - def index_put_failing_function_shape( - device, x_len, start_idx, left_window, right_window - ): - del device - del start_idx - del left_window - del right_window - return torch.empty((x_len, x_len), dtype=torch.bool).to(device) - - def register_custom_op(fct, fct_shape, namespace, fname): - schema_str = torch.library.infer_schema(fct, mutates_args=()) - custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct) - custom_def.register_kernel("cpu")(fct) - custom_def._abstract_fn = fct_shape # pylint: disable=protected-access - - register_custom_op( - index_put_failing_function, - index_put_failing_function_shape, - "test_delayed", - "index_put_failing_function", - ) - - class MyModule(torch.nn.Module): - def forward(self, X): - x_len = 10 # 368 - start_idx = [4] - left_window = 18 - result = torch.ops.test_delayed.index_put_failing_function( - "cpu", x_len, start_idx, left_window, 0 - ) - return X + torch.unsqueeze(result, -1) - - inputs = (torch.randn(1, 1, 368),) - ep = torch.export.export(MyModule(), args=inputs, strict=False) - self.assertIn("torch.ops.test_delayed.index_put_failing_function.default", str(ep)) - - -if __name__ == "__main__": - unittest.main(verbosity=2) From e28e753884813b3ca463d387b1936744d3a0c8ff Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 22 Jan 2025 10:15:59 -0800 Subject: [PATCH 12/18] Add trace_only to torch_op decorators --- onnxscript/function_libs/torch_lib/ops/core.py | 11 ++++++----- tests/function_libs/torch_lib/ops_test_data.py | 1 - 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b85cefad97..9db5d3d9e9 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4262,7 +4262,7 @@ def aten_index_copy( raise NotImplementedError() -@torch_op(("aten::index_put", "aten::_unsafe_index_put")) +@torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True) def aten_index_put( self: TReal, indices: Sequence[INT64], @@ -4276,10 +4276,10 @@ def aten_index_put( """ # TODO(justinchuby): Handle when indicies has more than one element - index = op.SequenceAt(indices, 0) + index = indices[0] new_index = op.Unsqueeze(index, [-1]) - if op.Cast(accumulate, to=BOOL.dtype): + if accumulate: result = op.ScatterND(self, new_index, values, reduction="add") else: result = op.ScatterND(self, new_index, values) @@ -4287,7 +4287,7 @@ def aten_index_put( return result -@torch_op("aten::index_put") +@torch_op("aten::index_put", trace_only=True) def aten_index_put_bool( self: TReal, indices: Sequence[BOOL], @@ -4296,7 +4296,8 @@ def aten_index_put_bool( ) -> TReal: """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor""" - index = op.SequenceAt(indices, 0) # assume indices only have 1 element + # TODO: Support indices with more than 1 elements + index = indices[0] # accumulate should be always False, True does not make sense but an assert would be great return op.Where(index, values, self) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 8422ab7306..a9d98bda64 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -857,7 +857,6 @@ def _where_input_wrangler( matcher=lambda sample: sample.args[0][0].dtype != torch.bool, reason="this Aten overload only supports tensor(bool) as indices", ) - .skip(reason="FIXME: https://github.com/microsoft/onnxscript/issues/1749"), TorchLibOpInfo( "index_put", core_ops.aten_index_put, From e11d1af3367366355f1c015720acf9a94f21720b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 22 Jan 2025 10:39:18 -0800 Subject: [PATCH 13/18] Fix misplaced closing parenthesis in ops_test_data.py --- tests/function_libs/torch_lib/ops_test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index eac2614ee9..f90bd5f36a 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -855,7 +855,7 @@ def _where_input_wrangler( .skip( matcher=lambda sample: sample.args[0][0].dtype != torch.bool, reason="this Aten overload only supports tensor(bool) as indices", - ) + ), TorchLibOpInfo( "index_put", core_ops.aten_index_put, From ced2d5363fc0eeaf73c4b391db9cdcf897453afd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 23 Jan 2025 13:38:12 +0100 Subject: [PATCH 14/18] Update onnxscript/function_libs/torch_lib/ops/core.py Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8f75bfac35..4f93971c26 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4298,6 +4298,11 @@ def aten_index_put_bool( # TODO: Support indices with more than 1 elements index = indices[0] # accumulate should be always False, True does not make sense but an assert would be great + # Reshape indices so it can be properly broadcasted + shape = [1] * len(self.shape) + shape[0] = -1 + shape = op.Constant(value_ints=shape) + index = op.Reshape(index, shape) return op.Where(index, values, self) From 5a6e11f8ab1e4193f5ca079130a6657cde591099 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 23 Jan 2025 13:43:58 +0100 Subject: [PATCH 15/18] fix index_put bool for other dimension --- .../function_libs/torch_lib/ops/core.py | 10 ++- .../function_libs/torch_lib/aten_ops_test.py | 83 +++++++++++++++++++ 2 files changed, 89 insertions(+), 4 deletions(-) create mode 100644 tests/function_libs/torch_lib/aten_ops_test.py diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4f93971c26..b220105caa 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4299,10 +4299,12 @@ def aten_index_put_bool( index = indices[0] # accumulate should be always False, True does not make sense but an assert would be great # Reshape indices so it can be properly broadcasted - shape = [1] * len(self.shape) - shape[0] = -1 - shape = op.Constant(value_ints=shape) - index = op.Reshape(index, shape) + lself, lindex = len(self.shape), len(index.shape) + if lself > lindex: + shape = op.Shape(index) + append = op.Constant(value_ints=[1 for _ in range(lself - lindex)]) + new_shape = op.Concat(shape, append, axis=0) + index = op.Reshape(index, new_shape) return op.Where(index, values, self) diff --git a/tests/function_libs/torch_lib/aten_ops_test.py b/tests/function_libs/torch_lib/aten_ops_test.py new file mode 100644 index 0000000000..d413de0779 --- /dev/null +++ b/tests/function_libs/torch_lib/aten_ops_test.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import onnxruntime +import torch + + +class TestOnnxExportAten(unittest.TestCase): + def test_aten_index_put_mask_bool_fixed_broadcast_2d(self): + class Model(torch.nn.Module): + def forward(self, x, values): + x = x.clone() + mask = torch.tensor([True, False, True, True, False]).to(torch.bool) + x[mask] = values + return x + + model = Model() + xs = ( + torch.arange(25).reshape((5, 5)).to(torch.float32), + torch.tensor([700, 800, 900, 1000, 1100], dtype=torch.float32), + ) + expected = model(*xs) + ep = torch.onnx.export(model, xs, dynamo=True) + sess = onnxruntime.InferenceSession( + ep.model_proto.SerializeToString(), + providers=["CPUExecutionProvider"], + ) + feeds = dict(zip([i.name for i in sess.get_inputs()], [x.numpy() for x in xs])) + got = sess.run(None, feeds)[0] + torch.testing.assert_close(expected, torch.from_numpy(got)) + + def test_aten_index_put_mask_bool_fixed_broadcast_3d(self): + class Model(torch.nn.Module): + def forward(self, x, values): + x = x.clone() + mask = torch.tensor([True, False]).to(torch.bool) + x[mask] = values + return x + # return torch.ops.aten.index_put(x, (mask,), values) + + model = Model() + xs = ( + torch.arange(2 * 3 * 5).reshape((2, 3, 5)).to(torch.float32), + torch.tensor([700, 800, 900, 1000, 1100], dtype=torch.float32), + ) + expected = model(*xs) + ep = torch.onnx.export(model, xs, dynamo=True) + sess = onnxruntime.InferenceSession( + ep.model_proto.SerializeToString(), + providers=["CPUExecutionProvider"], + ) + feeds = dict(zip([i.name for i in sess.get_inputs()], [x.numpy() for x in xs])) + got = sess.run(None, feeds)[0] + torch.testing.assert_close(expected, torch.from_numpy(got)) + + def test_aten_index_put_mask_bool_fixed_broadcast_3d_2(self): + class Model(torch.nn.Module): + def forward(self, x, values): + x = x.clone() + mask = torch.tensor([[True, False, False], [True, True, False]]).to(torch.bool) + x[mask] = values + return x + # return torch.ops.aten.index_put(x, (mask,), values) + + model = Model() + xs = ( + torch.arange(2 * 3 * 5).reshape((2, 3, 5)).to(torch.float32), + torch.tensor([700, 800, 900, 1000, 1100], dtype=torch.float32), + ) + expected = model(*xs) + ep = torch.onnx.export(model, xs, dynamo=True) + sess = onnxruntime.InferenceSession( + ep.model_proto.SerializeToString(), + providers=["CPUExecutionProvider"], + ) + feeds = dict(zip([i.name for i in sess.get_inputs()], [x.numpy() for x in xs])) + got = sess.run(None, feeds)[0] + torch.testing.assert_close(expected, torch.from_numpy(got)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From a53a2f98fddc751ed057f59a0cd1018e132b3d4c Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 23 Jan 2025 14:32:51 +0100 Subject: [PATCH 16/18] lint --- tests/function_libs/torch_lib/ops_test_data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index e56930eba5..35e1778ca2 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -852,8 +852,7 @@ def _where_input_wrangler( TorchLibOpInfo( "index_put_bool", core_ops.aten_index_put_bool, - ) - .skip( + ).skip( matcher=lambda sample: sample.args[0][0].dtype != torch.bool, reason="this Aten overload only supports tensor(bool) as indices", ), From 7d8a47a344deac243fbcc64d0745ee43a42fdbbc Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 23 Jan 2025 07:53:37 -0800 Subject: [PATCH 17/18] Update onnxscript/function_libs/torch_lib/ops/core.py --- onnxscript/function_libs/torch_lib/ops/core.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b220105caa..f980465bc4 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4299,12 +4299,13 @@ def aten_index_put_bool( index = indices[0] # accumulate should be always False, True does not make sense but an assert would be great # Reshape indices so it can be properly broadcasted - lself, lindex = len(self.shape), len(index.shape) - if lself > lindex: - shape = op.Shape(index) - append = op.Constant(value_ints=[1 for _ in range(lself - lindex)]) - new_shape = op.Concat(shape, append, axis=0) - index = op.Reshape(index, new_shape) + self_rank = len(self.shape) + index_rank = len(index.shape) + if self_rank > index_rank: + index_shape = op.Shape(index) + padding = op.Constant(value_ints=[1 for _ in range(self_rank - index_rank)]) + padded_shape = op.Concat(index_shape, padding, axis=0) + index = op.Reshape(index, padded_shape) return op.Where(index, values, self) From 87870f93f5e4da22bc04411aefa075e541520b15 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 23 Jan 2025 17:47:17 -0800 Subject: [PATCH 18/18] Delete tests/function_libs/torch_lib/aten_ops_test.py --- .../function_libs/torch_lib/aten_ops_test.py | 83 ------------------- 1 file changed, 83 deletions(-) delete mode 100644 tests/function_libs/torch_lib/aten_ops_test.py diff --git a/tests/function_libs/torch_lib/aten_ops_test.py b/tests/function_libs/torch_lib/aten_ops_test.py deleted file mode 100644 index d413de0779..0000000000 --- a/tests/function_libs/torch_lib/aten_ops_test.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import unittest - -import onnxruntime -import torch - - -class TestOnnxExportAten(unittest.TestCase): - def test_aten_index_put_mask_bool_fixed_broadcast_2d(self): - class Model(torch.nn.Module): - def forward(self, x, values): - x = x.clone() - mask = torch.tensor([True, False, True, True, False]).to(torch.bool) - x[mask] = values - return x - - model = Model() - xs = ( - torch.arange(25).reshape((5, 5)).to(torch.float32), - torch.tensor([700, 800, 900, 1000, 1100], dtype=torch.float32), - ) - expected = model(*xs) - ep = torch.onnx.export(model, xs, dynamo=True) - sess = onnxruntime.InferenceSession( - ep.model_proto.SerializeToString(), - providers=["CPUExecutionProvider"], - ) - feeds = dict(zip([i.name for i in sess.get_inputs()], [x.numpy() for x in xs])) - got = sess.run(None, feeds)[0] - torch.testing.assert_close(expected, torch.from_numpy(got)) - - def test_aten_index_put_mask_bool_fixed_broadcast_3d(self): - class Model(torch.nn.Module): - def forward(self, x, values): - x = x.clone() - mask = torch.tensor([True, False]).to(torch.bool) - x[mask] = values - return x - # return torch.ops.aten.index_put(x, (mask,), values) - - model = Model() - xs = ( - torch.arange(2 * 3 * 5).reshape((2, 3, 5)).to(torch.float32), - torch.tensor([700, 800, 900, 1000, 1100], dtype=torch.float32), - ) - expected = model(*xs) - ep = torch.onnx.export(model, xs, dynamo=True) - sess = onnxruntime.InferenceSession( - ep.model_proto.SerializeToString(), - providers=["CPUExecutionProvider"], - ) - feeds = dict(zip([i.name for i in sess.get_inputs()], [x.numpy() for x in xs])) - got = sess.run(None, feeds)[0] - torch.testing.assert_close(expected, torch.from_numpy(got)) - - def test_aten_index_put_mask_bool_fixed_broadcast_3d_2(self): - class Model(torch.nn.Module): - def forward(self, x, values): - x = x.clone() - mask = torch.tensor([[True, False, False], [True, True, False]]).to(torch.bool) - x[mask] = values - return x - # return torch.ops.aten.index_put(x, (mask,), values) - - model = Model() - xs = ( - torch.arange(2 * 3 * 5).reshape((2, 3, 5)).to(torch.float32), - torch.tensor([700, 800, 900, 1000, 1100], dtype=torch.float32), - ) - expected = model(*xs) - ep = torch.onnx.export(model, xs, dynamo=True) - sess = onnxruntime.InferenceSession( - ep.model_proto.SerializeToString(), - providers=["CPUExecutionProvider"], - ) - feeds = dict(zip([i.name for i in sess.get_inputs()], [x.numpy() for x in xs])) - got = sess.run(None, feeds)[0] - torch.testing.assert_close(expected, torch.from_numpy(got)) - - -if __name__ == "__main__": - unittest.main(verbosity=2)