diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 391eaaab5f64..fab4ae889dd7 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2914,7 +2914,7 @@ def from_onnx(self, graph, opset, get_output_expr=False): else: self._num_input += 1 if i_name in self._shape: - i_shape = self._shape[i_name] + i_shape = self._shape.pop(i_name) else: if "?" in str(i_shape): warning_msg = ( @@ -2929,6 +2929,11 @@ def from_onnx(self, graph, opset, get_output_expr=False): dtype = d_type self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype) self._inputs[i_name] = self._nodes[i_name] + assert ( + len(self._shape) == 0 + ), "User specified the shape for inputs that weren't found in the graph: " + str( + self._shape + ) # get list of unsupported ops convert_map = _get_convert_map(opset) unsupported_ops = set() diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 177bed66f466..5a6216ac705d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -19,6 +19,7 @@ from onnx import helper, TensorProto, mapping, numpy_helper import torch import torchvision +import pytest import tvm.topi.testing import tvm from tvm import relay @@ -57,7 +58,7 @@ def get_tvm_output_with_vm( mod = relay.transform.DynamicToStatic()(mod) ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) - result = ex.evaluate()(*input_data) + result = ex.evaluate()(*input_data, **params) if isinstance(result, tvm.runtime.NDArray): return result.asnumpy() return [r.asnumpy() for r in result] @@ -500,7 +501,7 @@ def test_squeeze(): model = helper.make_model(graph, producer_name="squeeze_test") x = np.random.uniform(size=in_shape).astype("float32") - verify_with_ort_with_inputs(model, [x], [out_shape]) + verify_with_ort_with_inputs(model, [x], [out_shape], opset=11) @tvm.testing.uses_gpu @@ -538,7 +539,7 @@ def test_unsqueeze(): ) model = helper.make_model(graph, producer_name="squeeze_test") - verify_with_ort(model, [in_shape]) + verify_with_ort(model, [in_shape], opset=11) def verify_gather(in_shape, indices, axis, dtype): @@ -1584,7 +1585,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0): pads = np.array(pads) # onnx graph if mode in ["edge", "reflect"]: - inputs = [indata, pads] + inputs = [indata] outdata = np.pad(indata, pad_width=np_pads, mode=mode) node = helper.make_node("Pad", inputs=["input", "pads"], outputs=["output"], mode=mode) graph = helper.make_graph( @@ -1600,7 +1601,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0): ], ) else: - inputs = [indata, pads, np.array([value]).astype("float32")] + inputs = [indata] outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) node = helper.make_node( "Pad", inputs=["input", "pads", "constant_value"], outputs=["output"], mode="constant" @@ -1663,7 +1664,7 @@ def verify_reduce_func(func, data, axis, keepdims): model = helper.make_model(graph, producer_name="reduce_test") - verify_with_ort_with_inputs(model, [data], [outshape]) + verify_with_ort_with_inputs(model, [data], [outshape], opset=11) @tvm.testing.uses_gpu @@ -4089,6 +4090,31 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): verify_cumsum(data, 1, 1, 1, type="int32") +def test_wrong_input(): + node = helper.make_node( + "Softplus", + inputs=["X"], + outputs=["Y"], + ) + + graph = helper.make_graph( + [node], + "softplus_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list([5]))], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list([5]))], + ) + model = helper.make_model(graph, producer_name="softplus_test") + + # Check that the graph can import correctly with proper shape definitions. + correct_shape_dict = {"X": [5]} + relay.frontend.from_onnx(model, shape=correct_shape_dict) + + # Check that an assertion is triggered when an input not in the graph is provided. + wrong_shape_dict = {"Z": [5]} + with pytest.raises(AssertionError): + relay.frontend.from_onnx(model, shape=wrong_shape_dict) + + if __name__ == "__main__": test_flatten() test_reshape() @@ -4167,3 +4193,4 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): test_maxunpool() test_softplus() test_cumsum() + test_wrong_input()