From 3fbf7cd4002c1ee6672210138b6844b804c123e9 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Tue, 11 Feb 2025 19:18:05 +0100 Subject: [PATCH 1/4] replaced topi.split with relax.op.split in onnx frontend fixed related onnx frontend unit tests --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 4 ++-- .../tvm/relax/transform/legalize_ops/manipulate.py | 13 +++++++------ src/relax/op/tensor/manipulate.cc | 12 ++++++------ 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 573cde982bea..e1f09233bcff 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1717,7 +1717,7 @@ def _impl_v1(cls, bb, inputs, attr, params): # When splits isnt specified divide evenly over axis. else: indices = attr["tvm_custom"]["num_outputs"] - return bb.emit_te(topi.split, inputs[0], indices, attr.get("axis", 0)) + return relax.op.split(inputs[0], indices, attr.get("axis", 0)) @classmethod def _impl_v13(cls, bb, inputs, attr, params): @@ -1738,7 +1738,7 @@ def _impl_v13(cls, bb, inputs, attr, params): # When splits isnt specified divide evenly over axis. else: indices = attr["tvm_custom"]["num_outputs"] - return bb.emit_te(topi.split, inputs[0], indices, axis=attr.get("axis", 0)) + return relax.op.split(inputs[0], indices, attr.get("axis", 0)) def get_prim_value_list(values): diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 55bc2772bcce..c71a41dc1c2d 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -112,12 +112,13 @@ def _split(bb: BlockBuilder, call: Call) -> Expr: modulo = tvm.arith.Analyzer().simplify( call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections ) - if modulo != 0: - logging.info( - "Split cannot be legalized by TOPI when the axis being split has " - "length that not divisible by the input number of section." - ) - return call + if isinstance(modulo, tir.IntImm): + if modulo != 0: + logging.info( + "Split cannot be legalized by TOPI when the axis being split has " + "length that not divisible by the input number of section." + ) + return call else: indices_or_sections = call.attrs.indices_or_sections return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 452b1f223a80..5dd84e72ecb0 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -864,7 +864,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { auto p_indices = opt_indices.value(); // When there is not index, return the input tensor's struct info. if (p_indices.size() == 0) { - return TupleStructInfo({data_sinfo}); + return data_sinfo; } // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. if (data_shape == nullptr) { @@ -911,7 +911,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { int n_section = p_n_section->value; // When the number of section is one, return the input tensor's struct info. if (n_section == 1) { - return TupleStructInfo({data_sinfo}); + return data_sinfo; } // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. if (data_shape == nullptr) { @@ -1895,8 +1895,8 @@ StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { PrimValue off_value = Downcast(call->args[2]); // Check if on_value and off_value have the same dtype ICHECK(on_value->value->dtype == off_value->value->dtype) - << "one_hot: on_value and off_value must have the same dtype, " - << "but got " << on_value->value->dtype << " and " << off_value->value->dtype; + << "one_hot: on_value and off_value must have the same dtype, " << "but got " + << on_value->value->dtype << " and " << off_value->value->dtype; DataType dtype = on_value->value->dtype; // Check if indices has an integer dtype @@ -1924,8 +1924,8 @@ StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { axis += output_shape.size() + 1; } ICHECK(0 <= axis && axis <= static_cast(output_shape.size())) - << "one_hot: axis must be in the range of [0, " << output_shape.size() << "], " - << "but got " << axis; + << "one_hot: axis must be in the range of [0, " << output_shape.size() << "], " << "but got " + << axis; output_shape.insert(output_shape.begin() + axis, attrs->depth); return TensorStructInfo(ShapeExpr(output_shape), dtype, indices_sinfo->vdevice); From 8dbdf445ba6465f1d8bafd80f92400487f5640ac Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Wed, 12 Feb 2025 02:20:16 +0100 Subject: [PATCH 2/4] updated formatting --- src/relax/op/tensor/manipulate.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 5dd84e72ecb0..cb738db363ee 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -1895,8 +1895,8 @@ StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { PrimValue off_value = Downcast(call->args[2]); // Check if on_value and off_value have the same dtype ICHECK(on_value->value->dtype == off_value->value->dtype) - << "one_hot: on_value and off_value must have the same dtype, " << "but got " - << on_value->value->dtype << " and " << off_value->value->dtype; + << "one_hot: on_value and off_value must have the same dtype, " + << "but got " << on_value->value->dtype << " and " << off_value->value->dtype; DataType dtype = on_value->value->dtype; // Check if indices has an integer dtype @@ -1924,8 +1924,8 @@ StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { axis += output_shape.size() + 1; } ICHECK(0 <= axis && axis <= static_cast(output_shape.size())) - << "one_hot: axis must be in the range of [0, " << output_shape.size() << "], " << "but got " - << axis; + << "one_hot: axis must be in the range of [0, " << output_shape.size() << "], " + << "but got " << axis; output_shape.insert(output_shape.begin() + axis, attrs->depth); return TensorStructInfo(ShapeExpr(output_shape), dtype, indices_sinfo->vdevice); From bfcb8c60079274e1e9fbd914326c79d51ea8eb5a Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Wed, 12 Feb 2025 02:29:35 +0100 Subject: [PATCH 3/4] updated the split unit tests for the case when only one tensor is returned --- tests/python/relax/test_op_manipulate.py | 28 +++++++++++------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index 23ab6780cf7b..86f0993bfbf6 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -2108,62 +2108,62 @@ def test_split_infer_struct_info_single_output(): _check_inference( bb, relax.op.split(x0, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo((a, b), "float32")]), + relax.TensorStructInfo((a, b), "float32"), ) _check_inference( bb, relax.op.split(x1, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32", ndim=2)]), + relax.TensorStructInfo(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.split(x2, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32")]), + relax.TensorStructInfo(dtype="float32"), ) _check_inference( bb, relax.op.split(x3, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s0, "float32")]), + relax.TensorStructInfo(s0, "float32"), ) _check_inference( bb, relax.op.split(x4, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s1, "float32")]), + relax.TensorStructInfo(s1, "float32"), ) _check_inference( bb, relax.op.split(x5, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s2, "float32")]), + relax.TensorStructInfo(s2, "float32"), ) _check_inference( bb, relax.op.split(x0, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo((a, b), "float32")]), + relax.TensorStructInfo((a, b), "float32"), ) _check_inference( bb, relax.op.split(x1, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32", ndim=2)]), + relax.TensorStructInfo(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.split(x2, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32")]), + relax.TensorStructInfo(dtype="float32"), ) _check_inference( bb, relax.op.split(x3, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s0, "float32")]), + relax.TensorStructInfo(s0, "float32"), ) _check_inference( bb, relax.op.split(x4, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s1, "float32")]), + relax.TensorStructInfo(s1, "float32"), ) _check_inference( bb, relax.op.split(x5, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s2, "float32")]), + relax.TensorStructInfo(s2, "float32"), ) @@ -2200,9 +2200,7 @@ def test_split_infer_struct_info(): _check_inference( bb, relax.op.split(x, 1), - R.Tuple( - R.Tensor([16, 4]), - ), + R.Tensor([16, 4]), ) _check_inference( bb, From d04f5164d18251d917b6bb70f0797efa6e58b698 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Wed, 12 Feb 2025 17:23:14 +0100 Subject: [PATCH 4/4] formatting --- tests/python/relax/test_op_manipulate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index 86f0993bfbf6..28e762d9a4de 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -2163,7 +2163,7 @@ def test_split_infer_struct_info_single_output(): _check_inference( bb, relax.op.split(x5, 1, axis=1), - relax.TensorStructInfo(s2, "float32"), + relax.TensorStructInfo(s2, "float32"), )