From 49de14b00f554567ac4d4c2abfbc485811bf5688 Mon Sep 17 00:00:00 2001 From: Gemfield Date: Mon, 18 Mar 2019 22:04:43 +0800 Subject: [PATCH] Enhance upsample operator to adapt onnx opset version 9 --- 3rdparty/HalideIR | 2 +- 3rdparty/dmlc-core | 2 +- python/tvm/relay/frontend/onnx.py | 7 ++++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index 86351c40824d..b257a9221ee1 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit 86351c40824dfc4cbb7447d70e5e63d9bd76eb90 +Subproject commit b257a9221ee1e5180d994b3488ddcc259b0ac157 diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 9acddddfc349..d07fb7a443b5 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 9acddddfc349eda4ef99552d11cb905afeafed39 +Subproject commit d07fb7a443b5db8a89d65a15a024af6a425615a5 diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index d322da31fc19..e92aa203b401 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -447,8 +447,13 @@ class Upsample(OnnxOpConverter): """ @classmethod - def _impl_v7(cls, inputs, attr, params): + def _impl_v9(cls, inputs, attr, params): scales = attr.get('scales') + if not scales: + #Here we are going to higher OPSET version. + assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs)) + scales = params[inputs[1].name_hint].asnumpy() + inputs = inputs[:1] assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3] mode = attr.get('mode') if mode == b'nearest':