Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2914,7 +2914,7 @@ def from_onnx(self, graph, opset, get_output_expr=False):
else:
self._num_input += 1
if i_name in self._shape:
i_shape = self._shape[i_name]
i_shape = self._shape.pop(i_name)
else:
if "?" in str(i_shape):
warning_msg = (
Expand All @@ -2929,6 +2929,11 @@ def from_onnx(self, graph, opset, get_output_expr=False):
dtype = d_type
self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
self._inputs[i_name] = self._nodes[i_name]
assert (
len(self._shape) == 0
), "User specified the shape for inputs that weren't found in the graph: " + str(
self._shape
)
# get list of unsupported ops
convert_map = _get_convert_map(opset)
unsupported_ops = set()
Expand Down
39 changes: 33 additions & 6 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from onnx import helper, TensorProto, mapping, numpy_helper
import torch
import torchvision
import pytest
import tvm.topi.testing
import tvm
from tvm import relay
Expand Down Expand Up @@ -57,7 +58,7 @@ def get_tvm_output_with_vm(
mod = relay.transform.DynamicToStatic()(mod)

ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
result = ex.evaluate()(*input_data)
result = ex.evaluate()(*input_data, **params)
if isinstance(result, tvm.runtime.NDArray):
return result.asnumpy()
return [r.asnumpy() for r in result]
Expand Down Expand Up @@ -500,7 +501,7 @@ def test_squeeze():

model = helper.make_model(graph, producer_name="squeeze_test")
x = np.random.uniform(size=in_shape).astype("float32")
verify_with_ort_with_inputs(model, [x], [out_shape])
verify_with_ort_with_inputs(model, [x], [out_shape], opset=11)


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -538,7 +539,7 @@ def test_unsqueeze():
)

model = helper.make_model(graph, producer_name="squeeze_test")
verify_with_ort(model, [in_shape])
verify_with_ort(model, [in_shape], opset=11)


def verify_gather(in_shape, indices, axis, dtype):
Expand Down Expand Up @@ -1584,7 +1585,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0):
pads = np.array(pads)
# onnx graph
if mode in ["edge", "reflect"]:
inputs = [indata, pads]
inputs = [indata]
outdata = np.pad(indata, pad_width=np_pads, mode=mode)
node = helper.make_node("Pad", inputs=["input", "pads"], outputs=["output"], mode=mode)
graph = helper.make_graph(
Expand All @@ -1600,7 +1601,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0):
],
)
else:
inputs = [indata, pads, np.array([value]).astype("float32")]
inputs = [indata]
outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value)
node = helper.make_node(
"Pad", inputs=["input", "pads", "constant_value"], outputs=["output"], mode="constant"
Expand Down Expand Up @@ -1663,7 +1664,7 @@ def verify_reduce_func(func, data, axis, keepdims):

model = helper.make_model(graph, producer_name="reduce_test")

verify_with_ort_with_inputs(model, [data], [outshape])
verify_with_ort_with_inputs(model, [data], [outshape], opset=11)


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -4089,6 +4090,31 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
verify_cumsum(data, 1, 1, 1, type="int32")


def test_wrong_input():
node = helper.make_node(
"Softplus",
inputs=["X"],
outputs=["Y"],
)

graph = helper.make_graph(
[node],
"softplus_test",
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list([5]))],
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list([5]))],
)
model = helper.make_model(graph, producer_name="softplus_test")

# Check that the graph can import correctly with proper shape definitions.
correct_shape_dict = {"X": [5]}
relay.frontend.from_onnx(model, shape=correct_shape_dict)

# Check that an assertion is triggered when an input not in the graph is provided.
wrong_shape_dict = {"Z": [5]}
with pytest.raises(AssertionError):
relay.frontend.from_onnx(model, shape=wrong_shape_dict)


if __name__ == "__main__":
test_flatten()
test_reshape()
Expand Down Expand Up @@ -4167,3 +4193,4 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
test_maxunpool()
test_softplus()
test_cumsum()
test_wrong_input()