diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 58b2ae3211..f980465bc4 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4261,7 +4261,7 @@ def aten_index_copy(
     raise NotImplementedError()
 
 
-@torch_op(("aten::index_put", "aten::_unsafe_index_put"))
+@torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
 def aten_index_put(
     self: TReal,
     indices: Sequence[INT64],
@@ -4275,10 +4275,10 @@ def aten_index_put(
     """
 
     # TODO(justinchuby): Handle when indicies has more than one element
-    index = op.SequenceAt(indices, 0)
+    index = indices[0]
     new_index = op.Unsqueeze(index, [-1])
 
-    if op.Cast(accumulate, to=BOOL.dtype):
+    if accumulate:
         result = op.ScatterND(self, new_index, values, reduction="add")
     else:
         result = op.ScatterND(self, new_index, values)
@@ -4286,7 +4286,7 @@ def aten_index_put(
     return result
 
 
-@torch_op("aten::index_put")
+@torch_op("aten::index_put", trace_only=True)
 def aten_index_put_bool(
     self: TReal,
     indices: Sequence[BOOL],
@@ -4295,37 +4295,18 @@ def aten_index_put_bool(
 ) -> TReal:
     """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""
 
-    index = op.SequenceAt(indices, 0)  # assume indices only have 1 element
-    # FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
-    index_int = op.Cast(index, to=INT32.dtype)
-    # if all False, return op.Identity(self)
-    if op.ReduceSum(index_int) == 0:
-        result = self
-    else:
-        # change array([F,F,T,F,F]) to array([2])
-        index = op.ArgMax(index_int)  # assume index only have 1 True
-        # change array([2]) to array([2,2,2,2,2])
-        self_dim_1 = op.Shape(self, start=1, end=2)
-        index_dim_0 = op.Shape(index, start=0, end=1)
-        shape = op.Concat(self_dim_1, index_dim_0, axis=0)
-        new_ind = op.Expand(index, shape)
-        new_ind_t = op.Transpose(new_ind)
-
-        # values must have same rank with input(self)
-        if op.Size(op.Shape(values)) < op.Size(op.Shape(self)):  # type: ignore[operator]
-            values = op.Unsqueeze(values, op.Constant(value_ints=[0]))
-
-        if op.Cast(accumulate, to=BOOL.dtype):
-            zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
-            zeros = op.CastLike(zeros, values)
-            result = op.ScatterElements(zeros, new_ind_t, values)
-            # FIXME: type promotion
-            result = op.CastLike(result, self)
-            result = op.Add(result, self)
-        else:
-            result = op.ScatterElements(self, new_ind_t, values)
-
-    return result
+    # TODO: Support indices with more than 1 elements
+    index = indices[0]
+    # accumulate should be always False, True does not make sense but an assert would be great
+    # Reshape indices so it can be properly broadcasted
+    self_rank = len(self.shape)
+    index_rank = len(index.shape)
+    if self_rank > index_rank:
+        index_shape = op.Shape(index)
+        padding = op.Constant(value_ints=[1 for _ in range(self_rank - index_rank)])
+        padded_shape = op.Concat(index_shape, padding, axis=0)
+        index = op.Reshape(index, padded_shape)
+    return op.Where(index, values, self)
 
 
 def aten_index_reduce(
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 8f40a50061..35e1778ca2 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -852,12 +852,10 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "index_put_bool",
         core_ops.aten_index_put_bool,
-    )
-    .skip(
+    ).skip(
         matcher=lambda sample: sample.args[0][0].dtype != torch.bool,
         reason="this Aten overload only supports tensor(bool) as indices",
-    )
-    .skip(reason="FIXME: https://github.com/microsoft/onnxscript/issues/1749"),
+    ),
     TorchLibOpInfo(
         "index_put",
         core_ops.aten_index_put,