From 166a482e73e9fce17ed23d93b4a8488908cc3d2d Mon Sep 17 00:00:00 2001
From: Siyuan Feng
Date: Mon, 11 Nov 2024 21:33:50 +0800
Subject: [PATCH] [Relax] Update ONNX frontend for unique, nonzero and compress

This PR updates the ONNX frontend:
- Add match casts for the unique and nonzero operators, giving their
  outputs symbolic shapes so that the rest of the model can be imported.
- Add support for the compress operator.
- Fix the output tensor shape of the nonzero operator.
---
 .../tvm/relax/frontend/onnx/onnx_frontend.py | 52 +++++++++++++++++--
 python/tvm/relax/op/set.py                   |  2 +-
 src/relax/op/tensor/set.cc                   |  4 +-
 tests/python/relax/test_frontend_onnx.py     | 30 ++++++++++-
 tests/python/relax/test_op_set.py            |  2 +-
 5 files changed, 81 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index eb7a3eaf3628..94ccfdb23e01 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -833,6 +833,32 @@ def _impl_v18(cls, bb, inputs, attr, params):
         return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)
 
 
+class Compress(OnnxOpConverter):
+    """Convert an onnx Compress node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        tensor, condition = inputs
+        axis = attr.get("axis", None)
+
+        # Convert the boolean mask to indices, e.g. [0, 1, 1, 0, 1] -> [1, 2, 4]
+        if condition.struct_info.dtype != "bool":
+            raise ValueError("Condition tensor is expected to be a boolean tensor")
+        if condition.struct_info.ndim != 1:
+            raise ValueError("Condition tensor is expected to be a 1D boolean tensor")
+        indices = relax.op.nonzero(condition)
+        num_nonzero = tir.Var("num_nonzero", "int64")
+        indices = bb.match_cast(indices, relax.TensorStructInfo([1, num_nonzero], "int64"))
+        indices = relax.op.reshape(indices, [-1])
+
+        if axis is not None:
+            return relax.op.take(tensor, indices, axis=axis)
+
+        # If axis is None, flatten the input tensor before selection
+        tensor = relax.op.reshape(tensor, (-1,))
+        return relax.op.take(tensor, indices, axis=0)
+
+
 class Size(OnnxOpConverter):
     """Convert an onnx Size node into an equivalent Relax expression."""
 
@@ -2726,7 +2752,22 @@ def _impl_v11(cls, bb, inputs, attr, params):
         axis = attr.get("axis", None)
         sorted = bool(attr.get("sorted", 1))
         # TODO(tvm-team): Add support for return_index, return_inverse, return_counts
-        return relax.op.unique(data, sorted=sorted, axis=axis)
+        unique = relax.op.unique(data, sorted=sorted, axis=axis)
+        unique_numbers = tir.Var("unique_numbers", "int64")
+        input_shape = data.struct_info.shape
+        dtype = data.struct_info.dtype
+
+        if axis is None:
+            # With axis=None the input is flattened, so the result is 1-D
+            return bb.match_cast(unique, relax.TensorStructInfo((unique_numbers,), dtype))
+
+        axis = axis if axis >= 0 else len(input_shape) + axis
+        if axis < 0 or axis >= len(input_shape):
+            raise ValueError(f"Axis {axis} is out of bounds")
+        output_shape = [
+            input_shape[i] if i != axis else unique_numbers for i in range(len(input_shape))
+        ]
+        return bb.match_cast(unique, relax.TensorStructInfo(output_shape, dtype))
 
 
 class NonZero(OnnxOpConverter):
@@ -2734,7 +2775,12 @@ class NonZero(OnnxOpConverter):
 
     @classmethod
     def _impl_v9(cls, bb, inputs, attr, params):
-        return relax.op.nonzero(inputs[0])
+        ndim = inputs[0].struct_info.ndim
+        ndim = 1 if ndim == 0 else ndim
+        nonzero_numbers = tir.Var("nonzero_numbers", "int64")
+        return bb.match_cast(
+            relax.op.nonzero(inputs[0]), relax.TensorStructInfo((ndim, nonzero_numbers), "int64")
+        )
 
 
 class HardSigmoid(OnnxOpConverter):
@@ -3075,7 +3121,7 @@ def _get_convert_map():
         "Scatter": Scatter,
         "ScatterElements": ScatterElements,
         "ScatterND": ScatterND,
-        # "Compress": Compress,
+        "Compress": Compress,
         "Size": Size,
         "EyeLike": EyeLike,
         # Normalization
diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py
index c5db852ddd5d..ed4b2e2ff928 100644
--- a/python/tvm/relax/op/set.py
+++ b/python/tvm/relax/op/set.py
@@ -123,7 +123,7 @@ def nonzero(x: Expr) -> Expr:
     Returns
     -------
     result : relax.Expr
-        A (n+1)-D tensor containing indices of non-zero elements.
+        A 2-D tensor containing indices of non-zero elements.
 
     Note
     ----
diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc
index c659a49afd12..e2aef8005e78 100644
--- a/src/relax/op/tensor/set.cc
+++ b/src/relax/op/tensor/set.cc
@@ -148,9 +148,7 @@ TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero);
 
 StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) {
   TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
-  // Cheat zero dim scalar as 1-dim.
-  int dim = data_sinfo->IsUnknownNdim() ? kUnknownNDim : std::max(1, data_sinfo->ndim) + 1;
-  return TensorStructInfo(DataType::Int(64), dim, data_sinfo->vdevice);
+  return TensorStructInfo(DataType::Int(64), 2, data_sinfo->vdevice);
 }
 
 TVM_REGISTER_OP("relax.nonzero")
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 6f74957a0781..a4a4f78bd3ef 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -601,6 +601,34 @@ def verify_scatter_nd(data_shape, indices_shape, updates_shape):
     verify_scatter_nd([10], [5, 1], [5])
 
 
+@pytest.mark.parametrize("tensor_shape", [[32, 32]])
+@pytest.mark.parametrize("condition_shape", [None, [8], [16]])
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_compress(
+    tensor_shape: List[int],
+    condition_shape: Optional[List[int]],
+    axis: Optional[int],
+):
+    if condition_shape is None and axis is None:
+        pytest.skip("Either condition_shape or axis must be specified")
+    if condition_shape is None:
+        condition_shape = [tensor_shape[axis]]
+    compress_node = helper.make_node("Compress", ["tensor", "condition"], ["output"], axis=axis)
+    graph = helper.make_graph(
+        [compress_node],
+        "compress_test",
+        inputs=[
+            helper.make_tensor_value_info("tensor", TensorProto.FLOAT, tensor_shape),
+            helper.make_tensor_value_info("condition", TensorProto.BOOL, condition_shape),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("output", TensorProto.FLOAT, [])
+        ],  # shape is unknown
+    )
+    model = helper.make_model(graph, producer_name="compress_test")
+    check_correctness(model, opset=11)
+
+
 def test_size():
     test_node = helper.make_node("Size", ["x"], ["y"])
     graph = helper.make_graph(
@@ -2478,7 +2506,7 @@ def test_unique(axis: Optional[int], sorted: int):
     check_correctness(model)
 
 
-@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6)])
+@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6), (7, 8, 9, 10)])
 def test_nonzero(shape):
     verify_unary("NonZero", shape, input_dtype=TensorProto.BOOL, output_dtype=TensorProto.INT64)
 
diff --git a/tests/python/relax/test_op_set.py b/tests/python/relax/test_op_set.py
index e9070f99fc3f..05b6d8887b7a 100644
--- a/tests/python/relax/test_op_set.py
+++ b/tests/python/relax/test_op_set.py
@@ -875,7 +875,7 @@ def test_nonzero_infer_struct_info(shape):
     _check_inference(
         bb,
         relax.op.nonzero(x0),
-        relax.TensorStructInfo(ndim=len(shape) + 1, dtype="int64"),
+        relax.TensorStructInfo(ndim=2, dtype="int64"),
     )