From 3d7a65285f90b592b2aff206f8fb898c9e7cfe1a Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 14 Jul 2021 22:52:00 +0000 Subject: [PATCH 1/4] Snapshot --- python/tvm/relay/frontend/onnx.py | 26 ++-- tests/python/frontend/onnx/test_forward.py | 167 +++++++++++---------- 2 files changed, 97 insertions(+), 96 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index aafc301be555..ec6b786ca717 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2741,24 +2741,24 @@ def get_var(name, val, scan=False): loop_var_names = [v.name_hint for v in loop_vars] num_scan_outputs = len(body.output) - (1 + num_deps) - # TODO (jwfromm) Test with strided slice once type unifier for this case is fixed. - if num_scan_outputs != 0 and "Slice" in [n.op_type for n in body.node]: - warnings.warn( - """ - Using scan outputs in a loop with strided slice - currently may cause errors during compilation. - """ - ) # Construct variables and intial empty tensors for any scan outputs. + # To do this, we'll figure out the output shapes of the body subgraph by importing + # it and doing type inference. scan_output_vars = [] scan_output_init = [] + if num_scan_outputs > 0: + with subgraph_scope: + loop_outputs = subgraph_scope.from_onnx( + body, graph_scope.opset, get_output_expr=True + ) + loop_outputs = _expr.TupleWrapper(loop_outputs, len(body.output)) + for i in range(num_scan_outputs): - name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps]) - if dtype is None: - dtype = infer_type(loop_deps[i]).checked_type.dtype - if dtype == "float": - dtype = "float32" + name, _, _, _ = get_info(body.output[i + 1 + num_deps]) + output_node = infer_type(loop_outputs[i + 1 + num_deps]) + shape = get_const_tuple(output_node.checked_type.shape) + dtype = output_node.checked_type.dtype scan_output_vars.append( _expr.var(name, shape=([_ty.Any()] * (len(shape) + 1)), dtype=dtype) ) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 049ca1e0cfe0..cfdc212affe6 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -61,6 +61,7 @@ def get_tvm_output_with_vm( if convert_to_static: mod = relay.transform.DynamicToStatic()(mod) + print(relay.transform.InferType()(mod)) ex = relay.create_executor("vm", mod=mod, device=device, target=target) result = ex.evaluate()(*input_data, **params) if isinstance(result, tvm.runtime.NDArray): @@ -5026,87 +5027,87 @@ def repeat(N, D): if __name__ == "__main__": - test_flatten() - test_reshape() - test_shape() - test_expand() - test_power() - test_squeeze() - test_unsqueeze() - test_slice() - test_floor() - test_ceil() - test_round() - test_isinf() - test_isnan() - test_clip() - test_clip_min_max_as_inputs() - test_onehot() - test_gemm() - test_matmul() - test_gather() - test_gatherelements() - test_gather_nd() - test_scatter() - test_lrn() - test_instance_norm() - test_upsample() - test_forward_min() - test_forward_max() - test_forward_mean() - test_forward_hardsigmoid() - test_forward_arg_min_max() - test_softmax() - test_constantofshape() - test_all_reduce_funcs() - test_pad() - test_split() - test_binary_ops() - test_unary_ops() - test_leaky_relu() - test_elu() - test_selu() - test_prelu() - test_ThresholdedRelu() - test_LogSoftmax() - test_resnet() - test_inception() - test_densenet() - test_sign() - test_not() - test_and() - test_tile() - test_erf() - test_where() - test_or() - test_depth_to_space() - 
test_space_to_depth() - test_batch_norm() - test_batch_norm_dynamic_subgraph() - test_conv() - test_convtranspose() - test_unsqueeze_constant() - test_pooling() - test_lppool() - test_lstm() - test_gru() - test_resize() - test_nonzero() - test_topk() - test_mod() - test_xor() - test_max_roi_pool() - test_roi_align() - test_range() + #test_flatten() + #test_reshape() + #test_shape() + #test_expand() + #test_power() + #test_squeeze() + #test_unsqueeze() + #test_slice() + #test_floor() + #test_ceil() + #test_round() + #test_isinf() + #test_isnan() + #test_clip() + #test_clip_min_max_as_inputs() + #test_onehot() + #test_gemm() + #test_matmul() + #test_gather() + #test_gatherelements() + #test_gather_nd() + #test_scatter() + #test_lrn() + #test_instance_norm() + #test_upsample() + #test_forward_min() + #test_forward_max() + #test_forward_mean() + #test_forward_hardsigmoid() + #test_forward_arg_min_max() + #test_softmax() + #test_constantofshape() + #test_all_reduce_funcs() + #test_pad() + #test_split() + #test_binary_ops() + #test_unary_ops() + #test_leaky_relu() + #test_elu() + #test_selu() + #test_prelu() + #test_ThresholdedRelu() + #test_LogSoftmax() + #test_resnet() + #test_inception() + #test_densenet() + #test_sign() + #test_not() + #test_and() + #test_tile() + #test_erf() + #test_where() + #test_or() + #test_depth_to_space() + #test_space_to_depth() + #test_batch_norm() + #test_batch_norm_dynamic_subgraph() + #test_conv() + #test_convtranspose() + #test_unsqueeze_constant() + #test_pooling() + #test_lppool() + #test_lstm() + #test_gru() + #test_resize() + #test_nonzero() + #test_topk() + #test_mod() + #test_xor() + #test_max_roi_pool() + #test_roi_align() + #test_range() test_loop() - test_size() - test_maxunpool() - test_softplus() - test_cumsum() - test_wrong_input() - test_aten() - test_reverse_sequence() - test_eyelike() - test_qlinearconv() - test_convinteger() - test_batch_matmul() + #test_size() + #test_maxunpool() + #test_softplus() + #test_cumsum() + #test_wrong_input() + #test_aten() + #test_reverse_sequence() + #test_eyelike() + #test_qlinearconv() + #test_convinteger() + #test_batch_matmul() From 6b00c5ceda153ea6f3be560a0d6209325ce65a5b Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 14 Jul 2021 22:57:31 +0000 Subject: [PATCH 2/4] Undo comments. 
--- tests/python/frontend/onnx/test_forward.py | 167 ++++++++++----------- 1 file changed, 83 insertions(+), 84 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index cfdc212affe6..049ca1e0cfe0 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -61,7 +61,6 @@ def get_tvm_output_with_vm( if convert_to_static: mod = relay.transform.DynamicToStatic()(mod) - print(relay.transform.InferType()(mod)) ex = relay.create_executor("vm", mod=mod, device=device, target=target) result = ex.evaluate()(*input_data, **params) if isinstance(result, tvm.runtime.NDArray): @@ -5027,87 +5026,87 @@ def repeat(N, D): if __name__ == "__main__": - #test_flatten() - #test_reshape() - #test_shape() - #test_expand() - #test_power() - #test_squeeze() - #test_unsqueeze() - #test_slice() - #test_floor() - #test_ceil() - #test_round() - #test_isinf() - #test_isnan() - #test_clip() - #test_clip_min_max_as_inputs() - #test_onehot() - #test_gemm() - #test_matmul() - #test_gather() - #test_gatherelements() - #test_gather_nd() - #test_scatter() - #test_lrn() - #test_instance_norm() - #test_upsample() - #test_forward_min() - #test_forward_max() - #test_forward_mean() - #test_forward_hardsigmoid() - #test_forward_arg_min_max() - #test_softmax() - #test_constantofshape() - #test_all_reduce_funcs() - #test_pad() - #test_split() - #test_binary_ops() - #test_unary_ops() - #test_leaky_relu() - #test_elu() - #test_selu() - #test_prelu() - #test_ThresholdedRelu() - #test_LogSoftmax() - #test_resnet() - #test_inception() - #test_densenet() - #test_sign() - #test_not() - #test_and() - #test_tile() - #test_erf() - #test_where() - #test_or() - #test_depth_to_space() - #test_space_to_depth() - #test_batch_norm() - #test_batch_norm_dynamic_subgraph() - #test_conv() - #test_convtranspose() - #test_unsqueeze_constant() - #test_pooling() - #test_lppool() - #test_lstm() - #test_gru() - #test_resize() - #test_nonzero() - #test_topk() - #test_mod() - #test_xor() - #test_max_roi_pool() - #test_roi_align() - #test_range() + test_flatten() + test_reshape() + test_shape() + test_expand() + test_power() + test_squeeze() + test_unsqueeze() + test_slice() + test_floor() + test_ceil() + test_round() + test_isinf() + test_isnan() + test_clip() + test_clip_min_max_as_inputs() + test_onehot() + test_gemm() + test_matmul() + test_gather() + test_gatherelements() + test_gather_nd() + test_scatter() + test_lrn() + test_instance_norm() + test_upsample() + test_forward_min() + test_forward_max() + test_forward_mean() + test_forward_hardsigmoid() + test_forward_arg_min_max() + test_softmax() + test_constantofshape() + test_all_reduce_funcs() + test_pad() + test_split() + test_binary_ops() + test_unary_ops() + test_leaky_relu() + test_elu() + test_selu() + test_prelu() + test_ThresholdedRelu() + test_LogSoftmax() + test_resnet() + test_inception() + test_densenet() + test_sign() + test_not() + test_and() + test_tile() + test_erf() + test_where() + test_or() + test_depth_to_space() + test_space_to_depth() + test_batch_norm() + test_batch_norm_dynamic_subgraph() + test_conv() + test_convtranspose() + test_unsqueeze_constant() + test_pooling() + test_lppool() + test_lstm() + test_gru() + test_resize() + test_nonzero() + test_topk() + test_mod() + test_xor() + test_max_roi_pool() + test_roi_align() + test_range() test_loop() - #test_size() - #test_maxunpool() - #test_softplus() - #test_cumsum() - #test_wrong_input() - #test_aten() - 
#test_reverse_sequence() - #test_eyelike() - #test_qlinearconv() - #test_convinteger() - #test_batch_matmul() + test_size() + test_maxunpool() + test_softplus() + test_cumsum() + test_wrong_input() + test_aten() + test_reverse_sequence() + test_eyelike() + test_qlinearconv() + test_convinteger() + test_batch_matmul() From 66a6543f03e781f1c7a4730ca37f87d4c24e9427 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 14 Jul 2021 23:09:11 +0000 Subject: [PATCH 3/4] Add testing for malformed loop nodes. --- tests/python/frontend/onnx/test_forward.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 049ca1e0cfe0..8043c028d68e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4043,7 +4043,7 @@ def verify_count_loop(): verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) -def verify_tensor_loop(): +def verify_tensor_loop(shapeless_output=False): y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [3, 3, 3, 3]) y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [3, 3, 3, 3]) scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [3, 3, 3, 3]) @@ -4076,6 +4076,13 @@ def verify_tensor_loop(): trip_count = np.array(5).astype(np.int64) cond = np.array(1).astype(bool) + + # Allow testing of malformed nodes since pytorch likes to create these. + if shapeless_output: + scan_shape = None + else: + scan_shape = [5, 3, 3, 3, 3] + loop_graph = onnx.helper.make_graph( [loop_node], "loop_outer", @@ -4086,7 +4093,7 @@ def verify_tensor_loop(): ], outputs=[ onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]), - onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 3, 3, 3, 3]), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, scan_shape), ], ) loop_model = onnx.helper.make_model(loop_graph) @@ -4106,6 +4113,8 @@ def test_loop(): verify_count_loop() # Test a loop that uses an array output. verify_tensor_loop() + # Test a loop that is malformed and has no output shape defined. + verify_tensor_loop(shapeless_output=True) def verify_if(cond_array, num_outputs): From e616561a9c053be8504a902493c47fa12746f725 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 14 Jul 2021 23:34:12 +0000 Subject: [PATCH 4/4] Format oops. --- python/tvm/relay/frontend/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ec6b786ca717..a3a9f494f4e2 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2751,7 +2751,7 @@ def get_var(name, val, scan=False): with subgraph_scope: loop_outputs = subgraph_scope.from_onnx( body, graph_scope.opset, get_output_expr=True - ) + ) loop_outputs = _expr.TupleWrapper(loop_outputs, len(body.output)) for i in range(num_scan_outputs):
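
Note (not part of the patches above): the case exercised by verify_tensor_loop(shapeless_output=True) can also be reproduced with a small standalone script. The sketch below builds a tiny ONNX Loop whose outer scan output ("res_scan") deliberately omits its shape, the "malformed" pattern that [PATCH 1/4] now handles by importing the body subgraph and type-inferring the scan output shape, and then imports and runs it through relay.frontend.from_onnx with the VM executor, mirroring get_tvm_output_with_vm from the test file. The tensor names, the [2] element shape, and the add-one body are illustrative choices and are not taken from the patches; API names (relay.create_executor, .numpy()) follow the test helpers at the time of these patches and may differ in newer TVM releases.

import numpy as np
import onnx
from onnx import TensorProto, helper

import tvm
from tvm import relay

# Illustrative standalone sketch, not a test from the patch series.
# Loop body: y_out = y_in + 1; scan_out records y_out on every iteration.
one_node = helper.make_node(
    "Constant",
    inputs=[],
    outputs=["one"],
    value=helper.make_tensor("one_value", TensorProto.FLOAT, [2], [1.0, 1.0]),
)
add_node = helper.make_node("Add", ["y_in", "one"], ["y_out"])
cond_node = helper.make_node("Identity", ["cond_in"], ["cond_out"])
scan_node = helper.make_node("Identity", ["y_out"], ["scan_out"])

body = helper.make_graph(
    [one_node, add_node, cond_node, scan_node],
    "loop_body",
    inputs=[
        helper.make_tensor_value_info("iter_count", TensorProto.INT64, []),
        helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []),
        helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [2]),
    ],
    outputs=[
        helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []),
        helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [2]),
        helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [2]),
    ],
)

# Outer graph: "res_scan" is declared with shape=None, i.e. no shape at all,
# matching the shapeless_output=True branch added in [PATCH 3/4].
loop_node = helper.make_node(
    "Loop", inputs=["trip_count", "cond", "y_init"], outputs=["res_y", "res_scan"], body=body
)
graph = helper.make_graph(
    [loop_node],
    "loop_outer",
    inputs=[
        helper.make_tensor_value_info("trip_count", TensorProto.INT64, []),
        helper.make_tensor_value_info("cond", TensorProto.BOOL, []),
        helper.make_tensor_value_info("y_init", TensorProto.FLOAT, [2]),
    ],
    outputs=[
        helper.make_tensor_value_info("res_y", TensorProto.FLOAT, [2]),
        helper.make_tensor_value_info("res_scan", TensorProto.FLOAT, None),
    ],
)
model = helper.make_model(graph)

# Import and execute with the Relay VM, as get_tvm_output_with_vm does.
mod, _ = relay.frontend.from_onnx(model, freeze_params=True)
ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(0), target="llvm")
result = ex.evaluate()(
    np.array(5).astype(np.int64),    # trip_count
    np.array(1).astype(bool),        # cond
    np.zeros([2], dtype="float32"),  # y_init
)
res_y, res_scan = [r.numpy() for r in result]
print(res_y)           # [5. 5.]
print(res_scan.shape)  # (5, 2): one slice per iteration stacked on a new leading axis

Before [PATCH 1/4], the importer read the scan output's shape and dtype from the outer graph's value_info, so a missing shape like the one above would break conversion; with the patch, the body subgraph is imported and type-inferred to recover that information, which is why this model converts and runs even though "res_scan" carries no shape.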