diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6e0c7cc2dd3f..93429a863889 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -6148,13 +6148,35 @@ def _impl_v11(cls, inputs, attr, params): return _expr.Tuple(inputs) -class SequenceLength(OnnxOpConverter): - """Operator converter for sequence length op.""" +class SequenceErase(OnnxOpConverter): + """Operator converter for sequence erase op.""" @classmethod def _impl_v11(cls, inputs, attr, params): - # Get length of input sequence - return _expr.const(len(inputs[0]), dtype="int64") + # Erase tensor from sequence on specified position + input_sequence = inputs[0] + + if len(inputs) == 2: + position = inputs[1] + # Non constant position is not supported. + if isinstance(position, _expr.Constant): + position = position.data.numpy() + elif position.name_hint in params: + position = params[position.name_hint].numpy() + else: + raise NotImplementedError("Position must be a constant.") + else: + position = -1 + + seq_len = len(input_sequence) + assert -seq_len <= position < seq_len, "Position is out of bounds" + + if position < 0: + position = seq_len + position + # Convert sequence to a list, insert tensors before erased, and repackage as Tuple. + tensor_list = [input_sequence[i] for i in range(seq_len) if i != position] + # Create new tuple and return. + return _expr.Tuple(tensor_list) class SequenceInsert(OnnxOpConverter): @@ -6188,6 +6210,15 @@ def _impl_v11(cls, inputs, attr, params): return _expr.Tuple(tensor_list) +class SequenceLength(OnnxOpConverter): + """Operator converter for sequence length op.""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + # Get length of input sequence + return _expr.const(len(inputs[0]), dtype="int64") + + class ConcatFromSequence(OnnxOpConverter): """Operator converter for sequence concatenation op.""" @@ -6492,8 +6523,9 @@ def _get_convert_map(opset): "LinearRegressor": LinearRegressor.get_converter(opset), # Sequence operators "SequenceConstruct": SequenceConstruct.get_converter(opset), - "SequenceLength": SequenceLength.get_converter(opset), + "SequenceErase": SequenceErase.get_converter(opset), "SequenceInsert": SequenceInsert.get_converter(opset), + "SequenceLength": SequenceLength.get_converter(opset), "ConcatFromSequence": ConcatFromSequence.get_converter(opset), "SplitToSequence": SplitToSequence.get_converter(opset), "SequenceAt": SequenceAt.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6a780a632fb7..3e1af4086784 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7747,10 +7747,17 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis= outputs=["inserted_sequence"], ) + # Test sequence erase. + erase_node = helper.make_node( + "SequenceErase", + inputs=["inserted_sequence", "position"], + outputs=["erased_sequence"], + ) + # Test sequence concatenation. concat_node = helper.make_node( "ConcatFromSequence", - inputs=["inserted_sequence"], + inputs=["erased_sequence"], outputs=["concat_sequence"], axis=axis, ) @@ -7796,6 +7803,7 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis= position_node, construct_node, insert_node, + erase_node, concat_node, split_node, at_node,