From da17fee580e57bb0f8d414c1091bc0fb56e72d40 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 5 Jun 2020 10:55:28 -0700 Subject: [PATCH 1/2] edit onnx parser to infer values in post order to speed up onnx imports with many calls to infer_value --- python/tvm/relay/frontend/onnx.py | 119 +++++++++++++++++++++++++++++- 1 file changed, 116 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 08027a287bba..92d9cf6884dc 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -27,12 +27,29 @@ from .. import function as _function from .. import op as _op from .. import vision as _vision + +from ..function import Function +from ..expr import Call, Let, Var, GlobalVar +from ..expr import If, Tuple, TupleGetItem, Constant +from ..expr import RefCreate, RefRead, RefWrite +from ..expr_functor import ExprFunctor +from ..adt import Constructor, Match, Clause + from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels -from .common import infer_type, infer_value, infer_value_simulated, get_name +from .common import infer_type, get_name +from .common import infer_value as _infer_value +from .common import infer_value_simulated as _infer_value_simulated __all__ = ['from_onnx'] +g = None + +def infer_value(input_val, params, mod=None): + return g.infer_value(input_val, params, mod) + +def infer_value_simulated(input_val, params): + return g.infer_value_simulated(input_val, params) class onnx_input(): """ Dual purpose list or dictionary access object.""" @@ -1879,8 +1896,7 @@ def _get_convert_map(opset): 'NonZero': NonZero.get_converter(opset), } - -class GraphProto(object): +class GraphProto(ExprFunctor): """A helper class for handling Relay expression copying from pb2.GraphProto. Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto @@ -1902,6 +1918,101 @@ def __init__(self, shape, dtype): self._shape = shape if shape else {} self._dtype = dtype + #For infering Values + self._tmp_params = {} + self._infer_simulated = True + self._mod = None + self.memo_map = {} + + def infer_value(self, input_val, params, mod=None): + self._tmp_params = params + self._infer_simulated = False + self._mod = mod + return self.visit(input_val).data + #return _infer_value(input_val, params, mod) + + def infer_value_simulated(self, input_val, params): + self._tmp_params = params + self._infer_simulated = True + return self.visit(input_val).data + #return _infer_value_simulated(input_val, params) + + def infer(self, expr): + if self._infer_simulated: + out = _infer_value_simulated(expr, self._tmp_params) + else: + out = _infer_value(expr, self._tmp_params) + return _expr.const(out.asnumpy()) + + def visit_function(self, fn): + new_params = [self.visit(x) for x in fn.params] + new_body = self.visit(fn.body) + return self.infer(Function( + list(new_params), + new_body, + fn.ret_type, + fn.type_params, + fn.attrs)) + + def visit_let(self, let): + new_var = self.visit(let.var) + new_val = self.visit(let.value) + new_body = self.visit(let.body) + return self.infer(Let(new_var, new_val, new_body)) + + def visit_call(self, call): + new_fn = self.visit(call.op) + new_args = [self.visit(arg) for arg in call.args] + return self.infer(Call(new_fn, new_args, call.attrs)) + + def visit_var(self, var): + return self.infer(var) + + def visit_global_id(self, global_var): + return self.infer(global_var) + + def visit_if(self, ite): + return self.infer(If( + self.visit(ite.cond), + self.visit(ite.true_branch), + self.visit(ite.false_branch))) + + def visit_tuple(self, tup): + return Tuple([self.visit(field) for field in tup.fields]) + + def visit_tuple_getitem(self, op): + tuple_value = self.visit(op.tuple_value) + if not tuple_value.same_as(op.tuple_value): + return self.infer(TupleGetItem(tuple_value, op.index)) + return self.infer(op) + + def visit_global_var(self, gvar): + return self.infer(gvar) + + def visit_op(self, op): + return op + + def visit_constant(self, const): + return const + + def visit_constructor(self, con): + return con + + def visit_match(self, m): + return self.infer(Match( + self.visit(m.data), + [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses], + complete=m.complete)) + + def visit_ref_create(self, r): + return RefCreate(self.visit(r.value)) + + def visit_ref_write(self, r): + return RefWrite(self.visit(r.ref), self.visit(r.value)) + + def visit_ref_read(self, r): + return RefRead(self.visit(r.ref)) + def from_onnx(self, graph, opset): """Construct Relay expression from ONNX graph. @@ -2160,6 +2271,7 @@ def from_onnx(model, warnings.warn(str(e)) except ImportError: pass + global g g = GraphProto(shape, dtype) graph = model.graph if opset is None: @@ -2168,4 +2280,5 @@ def from_onnx(model, except AttributeError: opset = 1 mod, params = g.from_onnx(graph, opset) + g = None return mod, params From b510eaf548369c84eeff5ffbc66d0013433d261f Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 9 Jun 2020 11:16:03 -0700 Subject: [PATCH 2/2] fix pylint --- python/tvm/relay/frontend/onnx.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 92d9cf6884dc..2bcd9e306f53 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -29,11 +29,11 @@ from .. import vision as _vision from ..function import Function -from ..expr import Call, Let, Var, GlobalVar -from ..expr import If, Tuple, TupleGetItem, Constant +from ..expr import Call, Let +from ..expr import If, Tuple, TupleGetItem from ..expr import RefCreate, RefRead, RefWrite from ..expr_functor import ExprFunctor -from ..adt import Constructor, Match, Clause +from ..adt import Match, Clause from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels @@ -1922,7 +1922,7 @@ def __init__(self, shape, dtype): self._tmp_params = {} self._infer_simulated = True self._mod = None - self.memo_map = {} + super(GraphProto, self).__init__() def infer_value(self, input_val, params, mod=None): self._tmp_params = params @@ -1955,10 +1955,10 @@ def visit_function(self, fn): fn.attrs)) def visit_let(self, let): - new_var = self.visit(let.var) - new_val = self.visit(let.value) - new_body = self.visit(let.body) - return self.infer(Let(new_var, new_val, new_body)) + newvar = self.visit(let.var) + newval = self.visit(let.value) + newbody = self.visit(let.body) + return self.infer(Let(newvar, newval, newbody)) def visit_call(self, call): new_fn = self.visit(call.op)