From db05af95ca7c7f7a87ca04709709e05bc81d763b Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 21 Oct 2020 20:46:20 +0000 Subject: [PATCH 1/6] If operator support in ONNX. --- python/tvm/relay/frontend/onnx.py | 45 ++++++++++++++++++- tests/python/frontend/onnx/test_forward.py | 50 +++++++++++++++++++++- 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index fa404efc39cf..1a82fe1cf917 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2146,7 +2146,9 @@ def body_fn(*loop_inputs): # Get the output of the current loop using the updated inputs. with subgraph_scope: - loop_outputs = subgraph_scope.from_onnx(body, 11, get_output_expr=True) + loop_outputs = subgraph_scope.from_onnx( + body, graph_scope.opset, get_output_expr=True + ) # Unpack the body outputs and prepare variables for next iteration. new_cond = loop_outputs[0] new_loop_vars = [loop_outputs[i] for i in range(1, 1 + num_deps)] @@ -2197,6 +2199,44 @@ def body_fn(*loop_inputs): return outputs +class If(OnnxOpConverter): + """Operator converter for If""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + cond = inputs[0] + then_branch = attr.get("then_branch", None) + else_branch = attr.get("else_branch", None) + assert then_branch is not None and else_branch is not None + + # Create graph converters for both branches. + graph_scope = GraphProto.current + then_graph = GraphProto(graph_scope._shape, graph_scope._dtype) + then_graph._nodes = graph_scope._nodes.copy() + else_graph = GraphProto(graph_scope._shape, graph_scope._dtype) + else_graph._nodes = graph_scope._nodes.copy() + + # Convert each branch to a relay expression. + with then_graph: + then_expr = then_graph.from_onnx(then_branch, graph_scope.opset, get_output_expr=True) + with else_graph: + else_expr = else_graph.from_onnx(else_branch, graph_scope.opset, get_output_expr=True) + + # Add constants from both branches to parent graph. + graph_scope._params.update(then_graph._params) + then_free_vars = analysis.free_vars(then_expr) + for var in then_free_vars: + graph_scope._nodes.update({var.name_hint: var}) + graph_scope._params.update(else_graph._params) + else_free_vars = analysis.free_vars(else_expr) + for var in else_free_vars: + graph_scope._nodes.update({var.name_hint: var}) + + # Now we can construct the relay if statement and return. + output = _expr.If(cond, then_expr, else_expr) + return output + + # compatible operators that do NOT require any conversion. 
_identity_list = [] @@ -2354,6 +2394,7 @@ def _get_convert_map(opset): "Range": Range.get_converter(opset), # defs/control_flow "Loop": Loop.get_converter(opset), + "If": If.get_converter(opset), } @@ -2381,6 +2422,7 @@ def __init__(self, shape, dtype): self._num_param = 0 self._shape = shape if shape else {} self._dtype = dtype + self.opset = None def __enter__(self): self._old_manager = GraphProto.current @@ -2436,6 +2478,7 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False): params : dict A dict of name: tvm.nd.array pairs, used as pretrained weights """ + self.opset = opset # parse network inputs to relay, aka parameters for init_tensor in graph.initializer: if not init_tensor.name.strip(): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index bf27ba5ddcd9..e3a8c379642a 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -17,7 +17,7 @@ import numpy as np import math import onnx -from onnx import helper, TensorProto, mapping +from onnx import helper, TensorProto, mapping, numpy_helper import torch import torchvision import tvm.topi.testing @@ -3841,6 +3841,53 @@ def test_loop(): verify_count_loop() +def test_if(): + # Given a bool scalar input cond. + # return constant tensor x if cond is True, otherwise return constant tensor y. + + then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5]) + else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5]) + + x = np.array([1, 2, 3, 4, 5]).astype(np.float32) + y = np.array([5, 4, 3, 2, 1]).astype(np.float32) + + then_const_node = onnx.helper.make_node( + "Constant", inputs=[], outputs=["then_out"], value=onnx.numpy_helper.from_array(x) + ) + + else_const_node = onnx.helper.make_node( + "Constant", inputs=[], outputs=["else_out"], value=onnx.numpy_helper.from_array(y) + ) + + then_body = onnx.helper.make_graph([then_const_node], "then_body", [], [then_out]) + + else_body = onnx.helper.make_graph([else_const_node], "else_body", [], [else_out]) + + if_node = onnx.helper.make_node( + "If", inputs=["cond"], outputs=["res"], then_branch=then_body, else_branch=else_body + ) + + if_graph = onnx.helper.make_graph( + [if_node], + "if_outer", + inputs=[ + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + ], + outputs=[ + onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, [5]), + ], + ) + + if_model = onnx.helper.make_model(if_graph) + cond = np.array(1).astype("bool") + onnx_out = get_onnxruntime_output(if_model, [cond], dtype="bool") + + for target, ctx in [("llvm", tvm.cpu())]: + tvm_out = get_tvm_output_with_vm(if_model, [cond], target, ctx, freeze_params=True) + for i in range(len(tvm_out)): + tvm.testing.assert_allclose(onnx_out[i], tvm_out[i], rtol=1e-05, atol=1e-05) + + if __name__ == "__main__": test_flatten() test_reshape() @@ -3917,3 +3964,4 @@ def test_loop(): test_roi_align() test_range() test_loop() + test_if() From 356e373be00e975ca30cf98b35c509670da438b8 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 21 Oct 2020 20:48:13 +0000 Subject: [PATCH 2/6] Small tweak. 
---
 python/tvm/relay/frontend/onnx.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 1a82fe1cf917..7ebad7297471 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2233,8 +2233,7 @@ def _impl_v1(cls, inputs, attr, params):
             graph_scope._nodes.update({var.name_hint: var})

         # Now we can construct the relay if statement and return.
-        output = _expr.If(cond, then_expr, else_expr)
-        return output
+        return _expr.If(cond, then_expr, else_expr)


 # compatible operators that do NOT require any conversion.

From e87cc98dfa269e1bd539a15e581362f67f44fcf8 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Mon, 26 Oct 2020 22:52:11 +0000
Subject: [PATCH 3/6] Added uses_gpu tag.

---
 tests/python/frontend/onnx/test_forward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index e3a8c379642a..7c7ba7dce74f 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3841,10 +3841,10 @@ def test_loop():
     verify_count_loop()


+@tvm.testing.uses_gpu
 def test_if():
     # Given a bool scalar input cond.
     # return constant tensor x if cond is True, otherwise return constant tensor y.
-
     then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5])
     else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5])


From 1a87df2e0b5fe9bc3288a1de3e4e0dd4fa193892 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Wed, 4 Nov 2020 20:39:45 +0000
Subject: [PATCH 4/6] Disable test on GPU until onnxruntime version is updated.

---
 tests/python/frontend/onnx/test_forward.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 7c7ba7dce74f..5ce38f94bf93 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3841,7 +3841,8 @@ def test_loop():
     verify_count_loop()


-@tvm.testing.uses_gpu
+# TODO(jwfromm): enable once onnxruntime version is updated in CI.
+# @tvm.testing.uses_gpu
 def test_if():
     # Given a bool scalar input cond.
     # return constant tensor x if cond is True, otherwise return constant tensor y.
@@ -3882,7 +3883,7 @@ def test_if():
     cond = np.array(1).astype("bool")
     onnx_out = get_onnxruntime_output(if_model, [cond], dtype="bool")

-    for target, ctx in [("llvm", tvm.cpu())]:
+    for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output_with_vm(if_model, [cond], target, ctx, freeze_params=True)
         for i in range(len(tvm_out)):
             tvm.testing.assert_allclose(onnx_out[i], tvm_out[i], rtol=1e-05, atol=1e-05)
@@ -3964,4 +3965,3 @@ def test_if():
     test_roi_align()
     test_range()
     test_loop()
-    test_if()

From 3c8c7e080b114b6ec600edb9af0509e5938e3dd0 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Thu, 5 Nov 2020 01:27:45 +0000
Subject: [PATCH 5/6] Use parametrize_targets to specify CPU only.
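
The parametrize_targets decorator turns test_if into a pytest test that is
parametrized over the listed targets, with matching `target` and `ctx`
arguments injected per run, so the test body no longer needs its own loop over
enabled targets. A minimal sketch of the pattern (the test name here is
illustrative, not part of this change):

    import tvm.testing

    @tvm.testing.parametrize_targets("llvm")
    def test_example(target, ctx):
        # pytest invokes this once per listed target; restricting the list to
        # "llvm" keeps the onnxruntime comparison on CPU for now.
        assert target == "llvm"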
---
 tests/python/frontend/onnx/test_forward.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 5ce38f94bf93..24d8ecf81068 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3841,9 +3841,9 @@ def test_loop():
     verify_count_loop()


-# TODO(jwfromm): enable once onnxruntime version is updated in CI.
-# @tvm.testing.uses_gpu
-def test_if():
+# TODO(jwfromm): enable cuda testing once onnxruntime version is updated in CI.
+@tvm.testing.parametrize_targets("llvm")
+def test_if(ctx, target):
     # Given a bool scalar input cond.
     # return constant tensor x if cond is True, otherwise return constant tensor y.
     then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5])
@@ -3883,10 +3883,9 @@ def test_if():
     cond = np.array(1).astype("bool")
     onnx_out = get_onnxruntime_output(if_model, [cond], dtype="bool")

-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output_with_vm(if_model, [cond], target, ctx, freeze_params=True)
-        for i in range(len(tvm_out)):
-            tvm.testing.assert_allclose(onnx_out[i], tvm_out[i], rtol=1e-05, atol=1e-05)
+    tvm_out = get_tvm_output_with_vm(if_model, [cond], target, ctx, freeze_params=True)
+    for i in range(len(tvm_out)):
+        tvm.testing.assert_allclose(onnx_out[i], tvm_out[i], rtol=1e-05, atol=1e-05)


 if __name__ == "__main__":

From 3e65b12d0eefa74c7818e6a766c8d6ba3cbd71a9 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Thu, 5 Nov 2020 18:05:46 +0000
Subject: [PATCH 6/6] Just don't use onnxruntime for now, I guess.

---
 tests/python/frontend/onnx/test_forward.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 24d8ecf81068..b84e55ac800c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3841,9 +3841,8 @@ def test_loop():
     verify_count_loop()


-# TODO(jwfromm): enable cuda testing once onnxruntime version is updated in CI.
-@tvm.testing.parametrize_targets("llvm")
-def test_if(ctx, target):
+@tvm.testing.uses_gpu
+def test_if():
     # Given a bool scalar input cond.
     # return constant tensor x if cond is True, otherwise return constant tensor y.
     then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5])
@@ -3881,11 +3880,12 @@ def test_if():

     if_model = onnx.helper.make_model(if_graph)
     cond = np.array(1).astype("bool")
-    onnx_out = get_onnxruntime_output(if_model, [cond], dtype="bool")
+    correct_out = x if cond else y

-    tvm_out = get_tvm_output_with_vm(if_model, [cond], target, ctx, freeze_params=True)
-    for i in range(len(tvm_out)):
-        tvm.testing.assert_allclose(onnx_out[i], tvm_out[i], rtol=1e-05, atol=1e-05)
+    for target, ctx in tvm.testing.enabled_targets():
+        tvm_out = get_tvm_output_with_vm(if_model, [cond], target, ctx, freeze_params=True)
+        for i in range(len(tvm_out)):
+            tvm.testing.assert_allclose(correct_out[i], tvm_out[i], rtol=1e-05, atol=1e-05)


 if __name__ == "__main__":
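
For reference, a minimal end-to-end sketch of exercising the new If support
through the public Relay frontend rather than the test helpers above. This is
illustrative only, not part of the patch series: `if_model` stands for the
onnx.ModelProto built in test_if, freeze_params=True mirrors the test helper
invocation, and an LLVM-enabled build is assumed.

    import numpy as np
    import tvm
    from tvm import relay

    # Import the ONNX graph; "cond" is the boolean scalar input of if_model.
    mod, params = relay.frontend.from_onnx(if_model, shape={"cond": []}, freeze_params=True)

    # Control flow needs the Relay VM (or interpreter), not the graph runtime.
    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    result = executor.evaluate()(np.array(True)).asnumpy()

    # With cond True, the then_branch constant [1, 2, 3, 4, 5] comes back.
    print(result)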