From 1ffd7986d5418fb4340f031e33b6553e5db6ef1d Mon Sep 17 00:00:00 2001 From: masa Date: Thu, 17 Sep 2020 09:26:26 +0900 Subject: [PATCH 1/5] clean up infer value usage --- python/tvm/relay/frontend/common.py | 8 ++++ python/tvm/relay/frontend/pytorch.py | 62 +++++++++++++--------------- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index e4d605aa4560..ac94d463aa4f 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -563,6 +563,14 @@ def infer_value_simulated(input_val, params): return output_value +def try_infer_value(val, on_success, on_failure): + try: + ret = infer_value(val, {}).asnumpy() + return on_success(ret), True + except Exception: + return on_failure(), False + + def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"): return _expr.var(name_hint, type_annotation, shape, dtype) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9ceb9fc66ec4..09bcea2ffe7e 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -36,6 +36,7 @@ from .common import AttrCvt, get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value +from .common import try_infer_value from .common import infer_value_simulated as _infer_value_simulated from .common import infer_type as _infer_type from ..prelude import Prelude, StaticTensorArrayOps @@ -185,11 +186,8 @@ def _impl(inputs, input_types): def _get_value(val, dtype): # dtype is a tvm dtype if isinstance(val, _expr.Expr): - try: - ret = _infer_value(_op.cast(val, dtype), {}).asnumpy() - ret = _expr.const(ret, dtype) - except Exception: - ret = _op.cast(val, dtype) + inp = _op.cast(val, dtype) + ret, _ = try_infer_value(inp, lambda ret: _expr.const(ret, dtype), lambda: inp) else: ret = _create_typed_const(val, dtype) return ret @@ -305,10 +303,9 @@ def _impl(inputs, input_types): dim = int(inputs[1]) stride = int(inputs[4]) if isinstance(inputs[2], _expr.Call): - try: - begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int)) - except Exception: - begin[dim] = inputs[2] + begin[dim], _ = try_infer_value( + inputs[2], lambda ret: np.asscalar(ret.astype(np.int)), lambda: inputs[2] + ) else: begin[dim] = int(inputs[2]) @@ -329,10 +326,9 @@ def _impl(inputs, input_types): target_end = int(inputs[3]) else: if isinstance(inputs[3], _expr.Expr): - try: - target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) - except Exception: - target_end = inputs[3] + target_end, _ = try_infer_value( + inputs[3], lambda ret: np.asscalar(ret.astype(np.int)), lambda: inputs[3] + ) else: target_end = inputs[3] @@ -457,10 +453,7 @@ def _impl(inputs, input_types): sort = bool(inputs[4]) if isinstance(inputs[1], _expr.Expr): - try: - k = _infer_value(inputs[1], {}).asnumpy().tolist() - except Exception: - k = inputs[1] + k, _ = try_infer_value(inputs[1], lambda ret: ret.tolist(), lambda: inputs[1]) else: k = inputs[1] @@ -546,15 +539,15 @@ def _full_impl(data, fill_value, dtype): size.append(dim) new_shape.append(dim) else: - try: - dim = int(_infer_value(dim, {}).asnumpy()) + dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0) + new_shape.append(dim) + + if success: if isinstance(size, list): size.append(dim) - new_shape.append(dim) - except Exception: + else: size = None need_reshape = True - new_shape.append(0) else: if isinstance(size, list): size.append(dim) @@ -1346,12 +1339,11 @@ def _impl(inputs, input_types): if isinstance(s, _expr.Constant): tmp_shape.append(int(s.data.asnumpy())) elif isinstance(s, _expr.Expr): - try: - dim = int(_infer_value(s, {}).asnumpy()) - tmp_shape.append(dim) - except Exception: + dim, success = try_infer_value(s, lambda ret: int(ret), lambda: s) + tmp_shape.append(dim) + + if not success: is_dyn = True - tmp_shape.append(s) else: tmp_shape.append(s) @@ -2312,13 +2304,15 @@ def _impl(inputs, input_types): if isinstance(inputs[1], _expr.Expr): out_size = inputs[1] elif isinstance(inputs[1], list): - try: - infer_res = [_infer_value(size, {}) for size in inputs[1]] - out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res] - except Exception: - h = _op.expand_dims(inputs[1][0], axis=0) - w = _op.expand_dims(inputs[1][1], axis=0) - out_size = _op.concatenate([h, w], axis=0) + out_size = [] + for i in [0, 1]: + size, success = try_infer_value( + inputs[1][i], + lambda ret: ret.astype(np.int), + lambda: _op.expand_dims(inputs[1][i], axis=0), + ) + out_size.append(size) + out_size = _op.concatenate(out_size, axis=0) data = inputs[0] align_corners = inputs[4] From d954678950145cd977ba6222b298e616fd2b8d00 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 17 Sep 2020 12:59:14 +0900 Subject: [PATCH 2/5] try silence pylint --- python/tvm/relay/frontend/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 09bcea2ffe7e..6a3500dacade 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks # pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except -# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension +# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda """PT: PyTorch frontend.""" import itertools import logging From 7593620ac2efc5c0d43b37fc80d73e460331e2f4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 17 Sep 2020 13:48:13 +0900 Subject: [PATCH 3/5] remove unused variable --- python/tvm/relay/frontend/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 6a3500dacade..d70ece5a5601 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2306,7 +2306,7 @@ def _impl(inputs, input_types): elif isinstance(inputs[1], list): out_size = [] for i in [0, 1]: - size, success = try_infer_value( + size, _ = try_infer_value( inputs[1][i], lambda ret: ret.astype(np.int), lambda: _op.expand_dims(inputs[1][i], axis=0), From 9ac55ca009d829a5d2b341c14911c800cad377d9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 22 Sep 2020 05:40:30 +0900 Subject: [PATCH 4/5] make on_failuare optional --- python/tvm/relay/frontend/common.py | 11 +++++++++-- python/tvm/relay/frontend/pytorch.py | 12 +++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index ac94d463aa4f..d5ad7576a3d1 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -563,12 +563,19 @@ def infer_value_simulated(input_val, params): return output_value -def try_infer_value(val, on_success, on_failure): +def try_infer_value(val, on_success, on_failure=None): + """Try running infer_value on the input val, and if successful, pass the inferred value to + on_success callback. Otherwise, run on_failure callback if it is provided or return the + input val as output. In each case, the second return value indicates whether infer_value has + succeeded or not. + """ try: ret = infer_value(val, {}).asnumpy() return on_success(ret), True except Exception: - return on_failure(), False + if on_failure: + return on_failure(), False + return val, False def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"): diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index d70ece5a5601..c667b0430f00 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -187,7 +187,7 @@ def _get_value(val, dtype): # dtype is a tvm dtype if isinstance(val, _expr.Expr): inp = _op.cast(val, dtype) - ret, _ = try_infer_value(inp, lambda ret: _expr.const(ret, dtype), lambda: inp) + ret, _ = try_infer_value(inp, lambda ret: _expr.const(ret, dtype)) else: ret = _create_typed_const(val, dtype) return ret @@ -303,9 +303,7 @@ def _impl(inputs, input_types): dim = int(inputs[1]) stride = int(inputs[4]) if isinstance(inputs[2], _expr.Call): - begin[dim], _ = try_infer_value( - inputs[2], lambda ret: np.asscalar(ret.astype(np.int)), lambda: inputs[2] - ) + begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) else: begin[dim] = int(inputs[2]) @@ -327,7 +325,7 @@ def _impl(inputs, input_types): else: if isinstance(inputs[3], _expr.Expr): target_end, _ = try_infer_value( - inputs[3], lambda ret: np.asscalar(ret.astype(np.int)), lambda: inputs[3] + inputs[3], lambda ret: np.asscalar(ret.astype(np.int)) ) else: target_end = inputs[3] @@ -453,7 +451,7 @@ def _impl(inputs, input_types): sort = bool(inputs[4]) if isinstance(inputs[1], _expr.Expr): - k, _ = try_infer_value(inputs[1], lambda ret: ret.tolist(), lambda: inputs[1]) + k, _ = try_infer_value(inputs[1], lambda ret: ret.tolist()) else: k = inputs[1] @@ -1339,7 +1337,7 @@ def _impl(inputs, input_types): if isinstance(s, _expr.Constant): tmp_shape.append(int(s.data.asnumpy())) elif isinstance(s, _expr.Expr): - dim, success = try_infer_value(s, lambda ret: int(ret), lambda: s) + dim, success = try_infer_value(s, lambda ret: int(ret)) tmp_shape.append(dim) if not success: From 3ff21f7c8109f9a60d6f5a1ea07020bfa4463156 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 22 Sep 2020 07:18:40 +0900 Subject: [PATCH 5/5] make on_success optional True --- python/tvm/relay/frontend/common.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index d5ad7576a3d1..027d6bd76141 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -563,15 +563,17 @@ def infer_value_simulated(input_val, params): return output_value -def try_infer_value(val, on_success, on_failure=None): - """Try running infer_value on the input val, and if successful, pass the inferred value to - on_success callback. Otherwise, run on_failure callback if it is provided or return the - input val as output. In each case, the second return value indicates whether infer_value has - succeeded or not. +def try_infer_value(val, on_success=None, on_failure=None): + """Try running infer_value on the input val, and if successful, return the inferred value or + pass it to on_success callback if provided. Otherwise, run on_failure callback if it is + provided, or return the input val as output. In each case, the second return value + indicates whether infer_value has succeeded or not. """ try: ret = infer_value(val, {}).asnumpy() - return on_success(ret), True + if on_success: + return on_success(ret), True + return ret, True except Exception: if on_failure: return on_failure(), False