diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index fa404efc39cf..7ebad7297471 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2146,7 +2146,9 @@ def body_fn(*loop_inputs): # Get the output of the current loop using the updated inputs. with subgraph_scope: - loop_outputs = subgraph_scope.from_onnx(body, 11, get_output_expr=True) + loop_outputs = subgraph_scope.from_onnx( + body, graph_scope.opset, get_output_expr=True + ) # Unpack the body outputs and prepare variables for next iteration. new_cond = loop_outputs[0] new_loop_vars = [loop_outputs[i] for i in range(1, 1 + num_deps)] @@ -2197,6 +2199,43 @@ def body_fn(*loop_inputs): return outputs +class If(OnnxOpConverter): + """Operator converter for If""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + cond = inputs[0] + then_branch = attr.get("then_branch", None) + else_branch = attr.get("else_branch", None) + assert then_branch is not None and else_branch is not None + + # Create graph converters for both branches. + graph_scope = GraphProto.current + then_graph = GraphProto(graph_scope._shape, graph_scope._dtype) + then_graph._nodes = graph_scope._nodes.copy() + else_graph = GraphProto(graph_scope._shape, graph_scope._dtype) + else_graph._nodes = graph_scope._nodes.copy() + + # Convert each branch to a relay expression. + with then_graph: + then_expr = then_graph.from_onnx(then_branch, graph_scope.opset, get_output_expr=True) + with else_graph: + else_expr = else_graph.from_onnx(else_branch, graph_scope.opset, get_output_expr=True) + + # Add constants from both branches to parent graph. + graph_scope._params.update(then_graph._params) + then_free_vars = analysis.free_vars(then_expr) + for var in then_free_vars: + graph_scope._nodes.update({var.name_hint: var}) + graph_scope._params.update(else_graph._params) + else_free_vars = analysis.free_vars(else_expr) + for var in else_free_vars: + graph_scope._nodes.update({var.name_hint: var}) + + # Now we can construct the relay if statement and return. + return _expr.If(cond, then_expr, else_expr) + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -2354,6 +2393,7 @@ def _get_convert_map(opset): "Range": Range.get_converter(opset), # defs/control_flow "Loop": Loop.get_converter(opset), + "If": If.get_converter(opset), } @@ -2381,6 +2421,7 @@ def __init__(self, shape, dtype): self._num_param = 0 self._shape = shape if shape else {} self._dtype = dtype + self.opset = None def __enter__(self): self._old_manager = GraphProto.current @@ -2436,6 +2477,7 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False): params : dict A dict of name: tvm.nd.array pairs, used as pretrained weights """ + self.opset = opset # parse network inputs to relay, aka parameters for init_tensor in graph.initializer: if not init_tensor.name.strip(): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index bf27ba5ddcd9..b84e55ac800c 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -17,7 +17,7 @@ import numpy as np import math import onnx -from onnx import helper, TensorProto, mapping +from onnx import helper, TensorProto, mapping, numpy_helper import torch import torchvision import tvm.topi.testing @@ -3841,6 +3841,53 @@ def test_loop(): verify_count_loop() +@tvm.testing.uses_gpu +def test_if(): + # Given a bool scalar input cond. + # return constant tensor x if cond is True, otherwise return constant tensor y. + then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5]) + else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5]) + + x = np.array([1, 2, 3, 4, 5]).astype(np.float32) + y = np.array([5, 4, 3, 2, 1]).astype(np.float32) + + then_const_node = onnx.helper.make_node( + "Constant", inputs=[], outputs=["then_out"], value=onnx.numpy_helper.from_array(x) + ) + + else_const_node = onnx.helper.make_node( + "Constant", inputs=[], outputs=["else_out"], value=onnx.numpy_helper.from_array(y) + ) + + then_body = onnx.helper.make_graph([then_const_node], "then_body", [], [then_out]) + + else_body = onnx.helper.make_graph([else_const_node], "else_body", [], [else_out]) + + if_node = onnx.helper.make_node( + "If", inputs=["cond"], outputs=["res"], then_branch=then_body, else_branch=else_body + ) + + if_graph = onnx.helper.make_graph( + [if_node], + "if_outer", + inputs=[ + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + ], + outputs=[ + onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, [5]), + ], + ) + + if_model = onnx.helper.make_model(if_graph) + cond = np.array(1).astype("bool") + correct_out = x if cond else y + + for target, ctx in tvm.testing.enabled_targets(): + tvm_out = get_tvm_output_with_vm(if_model, [cond], target, ctx, freeze_params=True) + for i in range(len(tvm_out)): + tvm.testing.assert_allclose(correct_out[i], tvm_out[i], rtol=1e-05, atol=1e-05) + + if __name__ == "__main__": test_flatten() test_reshape()