Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2741,24 +2741,24 @@ def get_var(name, val, scan=False):
loop_var_names = [v.name_hint for v in loop_vars]

num_scan_outputs = len(body.output) - (1 + num_deps)
# TODO (jwfromm) Test with strided slice once type unifier for this case is fixed.
if num_scan_outputs != 0 and "Slice" in [n.op_type for n in body.node]:
warnings.warn(
"""
Using scan outputs in a loop with strided slice
currently may cause errors during compilation.
"""
)

# Construct variables and intial empty tensors for any scan outputs.
# To do this, we'll figure out the output shapes of the body subgraph by importing
# it and doing type inference.
scan_output_vars = []
scan_output_init = []
if num_scan_outputs > 0:
with subgraph_scope:
loop_outputs = subgraph_scope.from_onnx(
body, graph_scope.opset, get_output_expr=True
)
loop_outputs = _expr.TupleWrapper(loop_outputs, len(body.output))

for i in range(num_scan_outputs):
name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps])
if dtype is None:
dtype = infer_type(loop_deps[i]).checked_type.dtype
if dtype == "float":
dtype = "float32"
name, _, _, _ = get_info(body.output[i + 1 + num_deps])
output_node = infer_type(loop_outputs[i + 1 + num_deps])
shape = get_const_tuple(output_node.checked_type.shape)
dtype = output_node.checked_type.dtype
scan_output_vars.append(
_expr.var(name, shape=([_ty.Any()] * (len(shape) + 1)), dtype=dtype)
)
Expand Down
13 changes: 11 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4043,7 +4043,7 @@ def verify_count_loop():
verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True)


def verify_tensor_loop():
def verify_tensor_loop(shapeless_output=False):
y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [3, 3, 3, 3])
y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [3, 3, 3, 3])
scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [3, 3, 3, 3])
Expand Down Expand Up @@ -4076,6 +4076,13 @@ def verify_tensor_loop():

trip_count = np.array(5).astype(np.int64)
cond = np.array(1).astype(bool)

# Allow testing of malformed nodes since pytorch likes to create these.
if shapeless_output:
scan_shape = None
else:
scan_shape = [5, 3, 3, 3, 3]

loop_graph = onnx.helper.make_graph(
[loop_node],
"loop_outer",
Expand All @@ -4086,7 +4093,7 @@ def verify_tensor_loop():
],
outputs=[
onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]),
onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 3, 3, 3, 3]),
onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, scan_shape),
],
)
loop_model = onnx.helper.make_model(loop_graph)
Expand All @@ -4106,6 +4113,8 @@ def test_loop():
verify_count_loop()
# Test a loop that uses an array output.
verify_tensor_loop()
# Test a loop that is malformed and has no output shape defined.
verify_tensor_loop(shapeless_output=True)


def verify_if(cond_array, num_outputs):
Expand Down