From 8fb938c891ded72045d0caf0a66cd4d33bc1c6e7 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Thu, 24 Oct 2024 16:01:39 +0200 Subject: [PATCH 1/9] updated slice and squeeze operators to work with dynamic shape expressions --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 6c9225070d3f..05406f524fe9 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1199,14 +1199,25 @@ class Squeeze(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr, params): + data = inputs[0] axis = get_constant(inputs[1], params) if isinstance(axis, relax.Constant): axis = [int(x) for x in axis.data.numpy()] # If data is constant, perform computation directly. if isinstance(inputs[0], relax.Constant): - out_data = _np.squeeze(inputs[0].data.numpy(), axis) - return relax.const(out_data, inputs[0].struct_info.dtype) - return relax.op.squeeze(inputs[0], axis) + out_data = _np.squeeze(data.data.numpy(), axis[0]) + return relax.const(out_data, data.struct_info.dtype) + + if isinstance(data, relax.ShapeExpr): + + if axis == [0]: + return relax.PrimValue(data[0]) + else: + raise NotImplementedError( + "Unsqueeze with symbolic axes and non-zero axes is not supported." + ) + + return relax.op.squeeze(data, axis) class Constant(OnnxOpConverter): @@ -1564,7 +1575,7 @@ def _impl_v13(cls, bb, inputs, attr, params): index = 0 for i in splits[:-1]: index += i - indices.append(index) + indices.append(index.item()) else: raise ValueError("Dynamic Split not yet supported") # When splits isnt specified divide evenly over axis. @@ -1611,11 +1622,17 @@ def _impl_v13(cls, bb, inputs, attr, params): steps = [1] * len(axes) # If input is a shape tensor, we can directly extract it. if isinstance(data, relax.ShapeExpr): - shape_data = [dim.value for dim in data] + + shape_data = [dim for dim in data] # Starts, ends, and steps must be 1-d for shape operation. assert all(len(i) == 1 for i in [starts, ends, steps]) sliced_values = shape_data[starts[0] : ends[0] : steps[0]] - return relax.const(sliced_values, "int64") + + if all([isinstance(val, (tir.IntImm, int)) for val in sliced_values]): + return relax.const([x.value for x in sliced_values], "int64") + else: + return relax.ShapeExpr(sliced_values) + # If all `starts`, `ends`, and `steps` are constant, use strict mode # Otherwise, we assume the slice is inbound. 
assume_inbound = not all( @@ -3220,6 +3237,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): "Equal", "Where", "Cast", + "Squeeze", ] return_tuple_ops = [ "SequenceConstruct", From c198ba971dcf7b3cc4646ea6fff7f059872e9b91 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Thu, 24 Oct 2024 16:19:52 +0200 Subject: [PATCH 2/9] updated formatting --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 05406f524fe9..3f29d215a2e0 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1209,7 +1209,6 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.const(out_data, data.struct_info.dtype) if isinstance(data, relax.ShapeExpr): - if axis == [0]: return relax.PrimValue(data[0]) else: @@ -1622,7 +1621,6 @@ def _impl_v13(cls, bb, inputs, attr, params): steps = [1] * len(axes) # If input is a shape tensor, we can directly extract it. if isinstance(data, relax.ShapeExpr): - shape_data = [dim for dim in data] # Starts, ends, and steps must be 1-d for shape operation. assert all(len(i) == 1 for i in [starts, ends, steps]) From 8e6de88a8e4d335ae6c8f5ba23b5884fbfde6eb2 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Thu, 24 Oct 2024 17:01:09 +0200 Subject: [PATCH 3/9] Add support for dynamic ShapeExpr in Flatten --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 3f29d215a2e0..6c1c54c08270 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1204,8 +1204,8 @@ def _impl_v13(cls, bb, inputs, attr, params): if isinstance(axis, relax.Constant): axis = [int(x) for x in axis.data.numpy()] # If data is constant, perform computation directly. - if isinstance(inputs[0], relax.Constant): - out_data = _np.squeeze(data.data.numpy(), axis[0]) + if isinstance(data, relax.Constant): + out_data = _np.squeeze(data.data.numpy(), tuple(axis)) return relax.const(out_data, data.struct_info.dtype) if isinstance(data, relax.ShapeExpr): @@ -1213,7 +1213,7 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.PrimValue(data[0]) else: raise NotImplementedError( - "Unsqueeze with symbolic axes and non-zero axes is not supported." + "Squeeze with symbolic axes and non-zero axes is not supported." 
) return relax.op.squeeze(data, axis) @@ -2252,8 +2252,24 @@ class Flatten(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr, params): axis = attr.get("axis", 1) - data_shape = [i.value for i in inputs[0].struct_info.shape] - new_shape = (1, -1) if axis == 0 else (_np.prod(data_shape[0:axis]).astype("int64"), -1) + data_shape = [i for i in inputs[0].struct_info.shape] + + if axis == 0: + new_shape = (1, -1) + else: + shape_flags = [isinstance(x, tvm.script.tir.IntImm) for x in data_shape[0:axis]] + + if all(shape_flags): + data_shape = [x.value for x in data_shape[0:axis]] + new_shape = (_np.prod(data_shape[0:axis]).astype("int64"), -1) + else: + batch_size = 1 + + for el in data_shape[0:axis]: + batch_size = batch_size * el + + new_shape = (batch_size, -1) + return relax.op.reshape(inputs[0], new_shape) From 85be408c438be090f8076b1c859e1b2a788cb7c3 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Thu, 24 Oct 2024 17:59:00 +0200 Subject: [PATCH 4/9] added missing indentation --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 6c1c54c08270..538e02ac5f50 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1203,18 +1203,19 @@ def _impl_v13(cls, bb, inputs, attr, params): axis = get_constant(inputs[1], params) if isinstance(axis, relax.Constant): axis = [int(x) for x in axis.data.numpy()] - # If data is constant, perform computation directly. - if isinstance(data, relax.Constant): - out_data = _np.squeeze(data.data.numpy(), tuple(axis)) - return relax.const(out_data, data.struct_info.dtype) - - if isinstance(data, relax.ShapeExpr): - if axis == [0]: - return relax.PrimValue(data[0]) - else: - raise NotImplementedError( - "Squeeze with symbolic axes and non-zero axes is not supported." - ) + + # If data is constant, perform computation directly. + if isinstance(data, relax.Constant): + out_data = _np.squeeze(data.data.numpy(), tuple(axis)) + return relax.const(out_data, data.struct_info.dtype) + + if isinstance(data, relax.ShapeExpr): + if axis == [0]: + return relax.PrimValue(data[0]) + else: + raise NotImplementedError( + "Squeeze with symbolic axes and non-zero axes is not supported." 
+ ) return relax.op.squeeze(data, axis) From b03d7eb1fbdf747b205ba753fe021459c6383db4 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Thu, 24 Oct 2024 18:15:15 +0200 Subject: [PATCH 5/9] changed asnumpy to numpy since asnumpy is deprecated --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 538e02ac5f50..4da109cf637b 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1570,7 +1570,7 @@ def _impl_v13(cls, bb, inputs, attr, params): splits_rank = splits.checked_type.ndim if splits is not None and splits_rank > 0: if isinstance(splits, relax.Constant): - splits = splits.data.asnumpy() + splits = splits.data.numpy() indices = [] index = 0 for i in splits[:-1]: From 9adf2dd124aefbb355b8de8cd44abc73eae5fab6 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Fri, 25 Oct 2024 12:54:44 +0200 Subject: [PATCH 6/9] fixed bug introduced in the Flatten converter updated formatting --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 4da109cf637b..3dae0a22c903 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1203,7 +1203,7 @@ def _impl_v13(cls, bb, inputs, attr, params): axis = get_constant(inputs[1], params) if isinstance(axis, relax.Constant): axis = [int(x) for x in axis.data.numpy()] - + # If data is constant, perform computation directly. if isinstance(data, relax.Constant): out_data = _np.squeeze(data.data.numpy(), tuple(axis)) @@ -1622,7 +1622,7 @@ def _impl_v13(cls, bb, inputs, attr, params): steps = [1] * len(axes) # If input is a shape tensor, we can directly extract it. if isinstance(data, relax.ShapeExpr): - shape_data = [dim for dim in data] + shape_data = list(data) # Starts, ends, and steps must be 1-d for shape operation. 
            assert all(len(i) == 1 for i in [starts, ends, steps])
            sliced_values = shape_data[starts[0] : ends[0] : steps[0]]
@@ -2253,7 +2253,7 @@ class Flatten(OnnxOpConverter):
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
         axis = attr.get("axis", 1)
-        data_shape = [i for i in inputs[0].struct_info.shape]
+        data_shape = list(inputs[0].struct_info.shape)
 
         if axis == 0:
             new_shape = (1, -1)
@@ -2262,7 +2262,7 @@ def _impl_v13(cls, bb, inputs, attr, params):
 
             if all(shape_flags):
                 data_shape = [x.value for x in data_shape[0:axis]]
-                new_shape = (_np.prod(data_shape[0:axis]).astype("int64"), -1)
+                new_shape = (_np.prod(data_shape).astype("int64"), -1)
             else:
                 batch_size = 1

From ef294c58170544ad09db5db7eb35ef9f677b1566 Mon Sep 17 00:00:00 2001
From: Patrik Persson
Date: Tue, 29 Oct 2024 15:20:26 +0100
Subject: [PATCH 7/9] removed indentation after the relax.Constant check and
 instead perform a check for tuple or None type axis where relevant

---
 .../tvm/relax/frontend/onnx/onnx_frontend.py | 30 +++++++++++--------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 3dae0a22c903..1a063c2fa806 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1202,20 +1202,26 @@ def _impl_v13(cls, bb, inputs, attr, params):
         data = inputs[0]
         axis = get_constant(inputs[1], params)
         if isinstance(axis, relax.Constant):
-            axis = [int(x) for x in axis.data.numpy()]
+            axis = tuple([int(x) for x in axis.data.numpy()])
 
-            # If data is constant, perform computation directly.
-            if isinstance(data, relax.Constant):
-                out_data = _np.squeeze(data.data.numpy(), tuple(axis))
-                return relax.const(out_data, data.struct_info.dtype)
+        # If data is constant, perform computation directly.
+        if isinstance(data, relax.Constant):
+            if isinstance(axis, (tuple, type(None))):
+                out_data = _np.squeeze(data.data.numpy(), axis)
+            else:
+                raise NotImplementedError(
+                    "Squeeze with symbolic axes not supported"
+                )
 
-            if isinstance(data, relax.ShapeExpr):
-                if axis == [0]:
-                    return relax.PrimValue(data[0])
-                else:
-                    raise NotImplementedError(
-                        "Squeeze with symbolic axes and non-zero axes is not supported."
-                    )
+            return relax.const(out_data, data.struct_info.dtype)
+
+        if isinstance(data, relax.ShapeExpr):
+            if axis == (0,):
+                return relax.PrimValue(data[0])
+            else:
+                raise NotImplementedError(
+                    "Squeeze with symbolic axes and non-zero axes is not supported."
+                )
 
         return relax.op.squeeze(data, axis)
 

From 310026b9ff29054cfdbcd432bf59bcf57a8196e7 Mon Sep 17 00:00:00 2001
From: Patrik Persson
Date: Tue, 29 Oct 2024 16:17:09 +0100
Subject: [PATCH 8/9] extracted generate_random_value function from
 generate_random_inputs

added isinstance(tvm_out, (int, float, bool)) in check_correctness since
the VM can return primitive types
added test case for squeeze with constant input
added test case for squeeze with dynamic input shape
added test case for squeeze with dynamic shape expression
added test case for slice with dynamic shape expression
fixed bug in test_split which caused the constant split case never to occur
---
 tests/python/relax/test_frontend_onnx.py | 220 ++++++++++++++++++++---
 1 file changed, 198 insertions(+), 22 deletions(-)

diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 46373510b101..050f6ca933aa 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -52,28 +52,35 @@ def generate_random_inputs(
         shape = []
         for dim in i.type.tensor_type.shape.dim:
             shape.append(dim.dim_value)
-
-        # Extract datatype for the input.
-        if i.type.tensor_type.elem_type:
-            dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[i.type.tensor_type.elem_type])
-        else:
-            dtype = "float32"
-
-        # Generate random inputs for each input.
-        if dtype == "bool":
-            # random_value = np.random.choice(a=[False, True], size=shape)
-            random_value = rg.choice(a=[False, True], size=shape)
-        elif dtype.startswith("int"):
-            # Keep non-zero values
-            random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
-            random_value[random_value <= 0] -= 1
-        else:
-            random_value = rg.standard_normal(size=shape).astype(dtype)
-        input_values[i.name] = random_value
+
+        input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type)
 
     return input_values
 
 
+def generate_random_value(
+    shape, elem_type
+) -> np.ndarray:
+
+    # Extract datatype for the input.
+ if dtype == "bool": + # random_value = np.random.choice(a=[False, True], size=shape) + random_value = rg.choice(a=[False, True], size=shape) + elif dtype.startswith("int"): + # Keep non-zero values + random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype) + random_value[random_value <= 0] -= 1 + else: + random_value = rg.standard_normal(size=shape).astype(dtype) + + return random_value + def check_correctness( model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None, @@ -156,6 +163,8 @@ def _check_output(tvm_out, ort_out): elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray): shape_out = tvm.nd.array([int(i) for i in tvm_out]) tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol) + elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray): + tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol) else: raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}") @@ -218,6 +227,30 @@ def verify_unary( model = helper.make_model(graph, producer_name="elemwise_test") check_correctness(model, opset=opset) +def verify_unary_dynamic_shape( + op_name, + shape, + shape_instance, + attrs={}, + domain=None, + input_dtype=TensorProto.FLOAT, + output_dtype=TensorProto.FLOAT, + opset=14, +): + test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain) + graph = helper.make_graph( + [test_node], + "elemwise_test", + inputs=[ + helper.make_tensor_value_info("x", input_dtype, shape), + ], + outputs=[helper.make_tensor_value_info("y", output_dtype, shape)], + ) + + model = helper.make_model(graph, producer_name="elemwise_test") + inputs = {"x": generate_random_value(shape_instance, input_dtype)} + check_correctness(model, inputs, opset=opset) + def verify_binary( op_name, shape_a, shape_b, shape_c, attrs={}, domain=None, dtype=TensorProto.FLOAT, opset=14 @@ -1012,6 +1045,81 @@ def test_squeeze(axis): model = helper.make_model(graph, producer_name="squeeze_test") check_correctness(model, opset=13) +@pytest.mark.parametrize("axis", [[0, 2], None]) +def test_squeeze_constant(axis): + shape = [1, 32, 1, 32] + constant= make_constant_node("x", onnx.TensorProto.FLOAT, shape, rg.standard_normal(size=shape).astype("float32")) + if axis: + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) + else: + squeeze_node = helper.make_node("Squeeze", ["x"], ["y"]) + + initializer = ( + [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if axis else None + ) + + graph = helper.make_graph( + [constant, squeeze_node], + "squeeze_test", + inputs=[], + initializer=initializer, + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="squeeze_test") + check_correctness(model, opset=13) + +@pytest.mark.parametrize("axis", [[0]]) +@pytest.mark.parametrize("A", [8, 16, 32]) +@pytest.mark.parametrize("B", [8, 16, 32]) +def test_dynamic_squeeze(axis, A, B): + + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) + shape = [1, "A", "B"] + + initializer = ( + [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if axis else None + ) + + graph = helper.make_graph( + [squeeze_node], + "squeeze_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + ], + initializer=initializer, + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, ["A", "B"])], + ) + + model = helper.make_model(graph, producer_name="squeeze_test") + 
inputs = {"x": rg.standard_normal(size=[1, A, B]).astype("float32")} + check_correctness(model, inputs, opset=13) + +@pytest.mark.parametrize("axis", [[0]]) +@pytest.mark.parametrize("A", [8, 16, 32]) +def test_dynamic_shape_squeeze(axis, A): + + shape_node = helper.make_node("Shape", ["x"], ["y"]) + squeeze_node = helper.make_node("Squeeze", ["y", "axes"], ["z"]) + shape = ["A"] + + initializer = ( + [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if axis else None + ) + + graph = helper.make_graph( + [shape_node, squeeze_node], + "squeeze_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + ], + initializer=initializer, + outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, [])], + ) + + model = helper.make_model(graph, producer_name="squeeze_test") + inputs = {"x": rg.standard_normal(size=[A]).astype("float32")} + check_correctness(model, inputs, opset=13) def test_const(): shape = [32, 32] @@ -1547,6 +1655,65 @@ def verify_slice(data_shape, output_shape, starts, ends, axes=None, steps=None): # steps=[-1, -3, -2], # ) +def test_slice_dynamic_shape(): + def verify_slice(data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None): + if isinstance(starts, list): + starts = np.array(starts, "int64") + if isinstance(ends, list): + ends = np.array(ends, "int64") + if isinstance(axes, list): + axes = np.array(axes, "int64") + if isinstance(steps, list): + steps = np.array(steps, "int64") + + slice_inputs = ["y", "starts", "ends"] + initializer = [ + helper.make_tensor("starts", TensorProto.INT64, starts.shape, starts), + helper.make_tensor("ends", TensorProto.INT64, ends.shape, ends), + ] + + if axes is not None: + initializer.append(helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes)) + slice_inputs.append("axes") + if steps is not None: + initializer.append(helper.make_tensor("steps", TensorProto.INT64, steps.shape, steps)) + slice_inputs.append("steps") + + shape_node = helper.make_node("Shape", inputs=["x"], outputs=["y"]) + slice_node = helper.make_node("Slice", inputs=slice_inputs, outputs=["z"]) + + graph = helper.make_graph( + [shape_node, slice_node], + "slice_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, data_shape), + ], + outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, output_shape)], + initializer=initializer, + ) + + model = helper.make_model(graph, producer_name="slice_test") + inputs = {"x": rg.standard_normal(size=data_instance_shape).astype("float32")} + check_correctness(model, inputs) + + verify_slice([20, 10, 5], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0]) + verify_slice(["A", 10, 5], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0]) + verify_slice(["A", "B", 5], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0]) + verify_slice([20, 10, "C"], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0]) + verify_slice(["A", "B", "C"], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0]) + + verify_slice([20, 10, 5], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0]) + verify_slice(["A", 10, 5], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0]) + verify_slice(["A", "B", 5], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0]) + verify_slice([20, 10, "C"], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0]) + verify_slice(["A", "B", "C"], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0]) + + verify_slice([20, 10, 5], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0]) + verify_slice(["A", 10, 5], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0]) + 
verify_slice(["A", "B", 5], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0]) + verify_slice([20, 10, "C"], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0]) + verify_slice(["A", "B", "C"], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0]) + # TODO Enable dynamism @pytest.mark.parametrize("dynamic", [False]) @@ -1795,12 +1962,13 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o ) ] + split_constant = None if pass_split: if opset >= 13: np_split = np.array(split).astype(np.int64) - initializer.append( - helper.make_tensor("split", TensorProto.INT64, list(np_split.shape), np_split) - ) + split_constant= make_constant_node("split", onnx.TensorProto.INT64, list(np_split.shape), np_split) + input_names.append("split") + node = helper.make_node( "Split", inputs=input_names, @@ -1812,8 +1980,10 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o split_attr = helper.make_attribute("split", split) node.attribute.append(split_attr) + nodes = [split_constant, node] if split_constant else [node] + graph = helper.make_graph( - [node], + nodes, "split_test", inputs=inputs, initializer=initializer, @@ -2226,6 +2396,12 @@ def test_flatten(): verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 2}) +def test_flatten_dynamic(): + verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 0}) + verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": -1}) + verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 2}) + + def test_onehot(): one_hot_node = helper.make_node("OneHot", ["indices", "depth", "values"], ["y"], axis=1) graph = helper.make_graph( From 3d5edc916a9a06dfc704d0ef9244daf34662b1b8 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Tue, 29 Oct 2024 18:19:55 +0100 Subject: [PATCH 9/9] applied black formatting --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 4 +- tests/python/relax/test_frontend_onnx.py | 49 ++++++++++++------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 1a063c2fa806..611f4348d55e 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1209,9 +1209,7 @@ def _impl_v13(cls, bb, inputs, attr, params): if isinstance(axis, (tuple, type(None))): out_data = _np.squeeze(data.data.numpy(), axis) else: - raise NotImplementedError( - "Squeeze with symbolic axes not supported" - ) + raise NotImplementedError("Squeeze with symbolic axes not supported") return relax.const(out_data, data.struct_info.dtype) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 050f6ca933aa..9faa441138fc 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -52,16 +52,14 @@ def generate_random_inputs( shape = [] for dim in i.type.tensor_type.shape.dim: shape.append(dim.dim_value) - + input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type) return input_values -def generate_random_value( - shape, elem_type -) -> np.ndarray: - +def generate_random_value(shape, elem_type) -> np.ndarray: + # Extract datatype for the input. 
if elem_type: dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type]) @@ -81,6 +79,7 @@ def generate_random_value( return random_value + def check_correctness( model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None, @@ -170,7 +169,7 @@ def _check_output(tvm_out, ort_out): # Check that number of outputs match. assert len(tvm_output) == len(ort_output), "Unequal number of outputs" - for (tvm_out, ort_out) in zip(tvm_output, ort_output): + for tvm_out, ort_out in zip(tvm_output, ort_output): # TODO Allow configurable tolerance. if ort_out is not None: _check_output(tvm_out, ort_out) @@ -227,6 +226,7 @@ def verify_unary( model = helper.make_model(graph, producer_name="elemwise_test") check_correctness(model, opset=opset) + def verify_unary_dynamic_shape( op_name, shape, @@ -246,7 +246,7 @@ def verify_unary_dynamic_shape( ], outputs=[helper.make_tensor_value_info("y", output_dtype, shape)], ) - + model = helper.make_model(graph, producer_name="elemwise_test") inputs = {"x": generate_random_value(shape_instance, input_dtype)} check_correctness(model, inputs, opset=opset) @@ -1045,11 +1045,14 @@ def test_squeeze(axis): model = helper.make_model(graph, producer_name="squeeze_test") check_correctness(model, opset=13) + @pytest.mark.parametrize("axis", [[0, 2], None]) def test_squeeze_constant(axis): shape = [1, 32, 1, 32] - constant= make_constant_node("x", onnx.TensorProto.FLOAT, shape, rg.standard_normal(size=shape).astype("float32")) - if axis: + constant = make_constant_node( + "x", onnx.TensorProto.FLOAT, shape, rg.standard_normal(size=shape).astype("float32") + ) + if axis: squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) else: squeeze_node = helper.make_node("Squeeze", ["x"], ["y"]) @@ -1069,11 +1072,12 @@ def test_squeeze_constant(axis): model = helper.make_model(graph, producer_name="squeeze_test") check_correctness(model, opset=13) + @pytest.mark.parametrize("axis", [[0]]) @pytest.mark.parametrize("A", [8, 16, 32]) @pytest.mark.parametrize("B", [8, 16, 32]) def test_dynamic_squeeze(axis, A, B): - + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) shape = [1, "A", "B"] @@ -1092,13 +1096,14 @@ def test_dynamic_squeeze(axis, A, B): ) model = helper.make_model(graph, producer_name="squeeze_test") - inputs = {"x": rg.standard_normal(size=[1, A, B]).astype("float32")} + inputs = {"x": rg.standard_normal(size=[1, A, B]).astype("float32")} check_correctness(model, inputs, opset=13) + @pytest.mark.parametrize("axis", [[0]]) @pytest.mark.parametrize("A", [8, 16, 32]) def test_dynamic_shape_squeeze(axis, A): - + shape_node = helper.make_node("Shape", ["x"], ["y"]) squeeze_node = helper.make_node("Squeeze", ["y", "axes"], ["z"]) shape = ["A"] @@ -1118,9 +1123,10 @@ def test_dynamic_shape_squeeze(axis, A): ) model = helper.make_model(graph, producer_name="squeeze_test") - inputs = {"x": rg.standard_normal(size=[A]).astype("float32")} + inputs = {"x": rg.standard_normal(size=[A]).astype("float32")} check_correctness(model, inputs, opset=13) + def test_const(): shape = [32, 32] const_node = helper.make_node( @@ -1655,8 +1661,11 @@ def verify_slice(data_shape, output_shape, starts, ends, axes=None, steps=None): # steps=[-1, -3, -2], # ) + def test_slice_dynamic_shape(): - def verify_slice(data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None): + def verify_slice( + data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None + ): if isinstance(starts, list): starts = np.array(starts, "int64") if 
isinstance(ends, list): @@ -1678,10 +1687,10 @@ def verify_slice(data_shape, data_instance_shape, output_shape, starts, ends, ax if steps is not None: initializer.append(helper.make_tensor("steps", TensorProto.INT64, steps.shape, steps)) slice_inputs.append("steps") - + shape_node = helper.make_node("Shape", inputs=["x"], outputs=["y"]) slice_node = helper.make_node("Slice", inputs=slice_inputs, outputs=["z"]) - + graph = helper.make_graph( [shape_node, slice_node], "slice_test", @@ -1966,7 +1975,9 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o if pass_split: if opset >= 13: np_split = np.array(split).astype(np.int64) - split_constant= make_constant_node("split", onnx.TensorProto.INT64, list(np_split.shape), np_split) + split_constant = make_constant_node( + "split", onnx.TensorProto.INT64, list(np_split.shape), np_split + ) input_names.append("split") node = helper.make_node( @@ -2398,8 +2409,8 @@ def test_flatten(): def test_flatten_dynamic(): verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 0}) - verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": -1}) - verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 2}) + verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": -1}) + verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 2}) def test_onehot():
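
Taken together, the series lets the Relax ONNX importer carry symbolic dimensions through Shape -> Slice -> Squeeze chains instead of requiring constant shapes. The sketch below is illustrative only (it is not part of the patches); it assumes the patched `from_onnx` frontend and mirrors the pattern exercised by test_dynamic_shape_squeeze and test_slice_dynamic_shape above.

# Illustrative sketch (not part of this series): an ONNX model whose output
# is a symbolic dimension extracted through a Shape -> Slice -> Squeeze
# chain. Assumes the patched Relax ONNX frontend from these commits.
import onnx
from onnx import TensorProto, helper

from tvm.relax.frontend.onnx import from_onnx

shape_node = helper.make_node("Shape", ["x"], ["s"])
slice_node = helper.make_node("Slice", ["s", "starts", "ends"], ["s0"])
squeeze_node = helper.make_node("Squeeze", ["s0", "axes"], ["dim"])

graph = helper.make_graph(
    [shape_node, slice_node, squeeze_node],
    "symbolic_dim_example",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, ["A", 10, 5])],
    outputs=[helper.make_tensor_value_info("dim", TensorProto.INT64, [])],
    initializer=[
        helper.make_tensor("starts", TensorProto.INT64, [1], [0]),
        helper.make_tensor("ends", TensorProto.INT64, [1], [1]),
        helper.make_tensor("axes", TensorProto.INT64, [1], [0]),
    ],
)
model = helper.make_model(graph, producer_name="symbolic_dim_example")

# Slicing the ShapeExpr keeps the symbolic dim "A" (returned as a ShapeExpr),
# and Squeeze on axis 0 turns it into a relax.PrimValue instead of failing
# on the non-constant shape, as the Slice and Squeeze converters above allow.
mod = from_onnx(model, opset=13)
mod.show()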