From e0c40663e89820a1f377e7f0200d68ea7a51d1c5 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Wed, 21 Oct 2020 13:21:49 -0700 Subject: [PATCH 1/5] Minor fix for some tf od models --- include/tvm/topi/transform.h | 15 +++++++------ python/tvm/relay/frontend/tensorflow.py | 22 ++++++++++++++----- python/tvm/relay/op/_transform.py | 5 +++++ .../python/topi/python/test_topi_transform.py | 1 + 4 files changed, 31 insertions(+), 12 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index e01eb703cb99..79d594fab601 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -896,10 +896,7 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, if (x->shape.size() == 0) { return compute( condition->shape, - [&](const Array& indices) { - Array condition_idx{indices[0]}; - return tvm::tir::Select(condition(condition_idx) != 0, x(), y()); - }, + [&](const Array& indices) { return tvm::tir::Select(condition() != 0, x(), y()); }, name, tag); } else if (condition->shape.size() != 1) { CHECK_EQ(condition->shape.size(), x->shape.size()) @@ -913,9 +910,13 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, }, name, tag); } else { - CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0])) - << "If condition is 1-D, the first dimension must be the same as x: " << condition->shape[0] - << " vs " << x->shape[0]; + int64_t cond_first_dim = topi::GetConstInt(condition->shape[0]); + int64_t x_first_dim = topi::GetConstInt(x->shape[0]); + if (cond_first_dim > 0 && x_first_dim > 0) { + CHECK_EQ(cond_first_dim, x_first_dim) + << "If condition is 1-D, the first dimension must be the same as x: " << cond_first_dim + << " vs " << x_first_dim; + } return compute( x->shape, [&](const Array& indices) { diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 3df582a0c76a..9671e45a59a3 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1549,7 +1549,7 @@ def _impl(inputs, attr, params, mod): idx += st # Only return when in_shape is fully static in the range from begin to end. - if idx >= st: + if idx >= ed: ret = _expr.const(out_data, dtype) if shrink_axis_mask: ret = _op.squeeze(ret) @@ -1659,14 +1659,26 @@ def _transform_mask(stride_dim, ellipsis_mask): def _pad(name): def _impl(inputs, attr, params, mod): - padlist = _get_param(params, inputs[1]) - paddings = tuple(tuple(l) for l in padlist) + try: + padlist = _get_param(params, inputs[1]) + except (IndexError, KeyError, AttributeError): + try: + padlist = _infer_value(inputs[1], params, mod).asnumpy().tolist() + except Exception: + padlist = inputs[1] + + if isinstance(padlist, _expr.Expr): + paddings = padlist + else: + paddings = tuple(tuple(l) for l in padlist) attr["pad_width"] = paddings attr["pad_value"] = 0 new_inputs = [inputs[0]] if name == "PadV2": - constant_values = _get_num_param(params, inputs[2]) - attr["pad_value"] = constant_values + try: + attr["pad_value"] = _get_num_param(params, inputs[2]) + except (IndexError, KeyError, AttributeError): + attr["pad_value"] = inputs[2] return AttrCvt( op_name="pad", ignores=["Tpaddings"], diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 8f32d2c6f652..ecfa44f160b0 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -802,3 +802,8 @@ def stack_shape_func(attrs, inputs, _): if axis < 0: axis += inputs[0].shape[0] + 1 return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))] + + +@_reg.register_shape_func("where", False) +def where_shape_func(attrs, inputs, _): + return [topi.math.identity(inputs[1])] diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index f18b5397eefe..cdf0b8319087 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -839,6 +839,7 @@ def test_reshape(): @tvm.testing.uses_gpu def test_where(): + verify_where(()) verify_where((1, 2, 3, 4)) From 38da3f17d31d9fcdee06caac6c2909eb0c0911a5 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Wed, 21 Oct 2020 14:58:40 -0700 Subject: [PATCH 2/5] More fix --- include/tvm/topi/transform.h | 11 ++++++++++- python/tvm/relay/op/_transform.py | 10 +++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 79d594fab601..aa5c6d2a2256 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -896,7 +896,16 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, if (x->shape.size() == 0) { return compute( condition->shape, - [&](const Array& indices) { return tvm::tir::Select(condition() != 0, x(), y()); }, + [&](const Array& indices) { + PrimExpr cond; + if (condition->shape.size() == 0) { + cond = condition(); + } else { + Array condition_idx{indices[0]}; + cond = condition(condition_idx); + } + return tvm::tir::Select(cond != 0, x(), y()); + }, name, tag); } else if (condition->shape.size() != 1) { CHECK_EQ(condition->shape.size(), x->shape.size()) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index ecfa44f160b0..a396585598b6 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -806,4 +806,12 @@ def stack_shape_func(attrs, inputs, _): @_reg.register_shape_func("where", False) def where_shape_func(attrs, inputs, _): - return [topi.math.identity(inputs[1])] + cond_shape = inputs[0] + x_shape = inputs[1] + + if len(x_shape.shape) == 0: + out_shape = cond_shape + else: + out_shape = x_shape + + return topi.math.identity(out_shape) From a01f76632ad893c4493f6f481e32b1b8748accb6 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Wed, 21 Oct 2020 15:10:38 -0700 Subject: [PATCH 3/5] Minor fix --- python/tvm/relay/op/_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index a396585598b6..a34c5dbf25a9 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -814,4 +814,4 @@ def where_shape_func(attrs, inputs, _): else: out_shape = x_shape - return topi.math.identity(out_shape) + return [topi.math.identity(out_shape)] From 61a195fabcec6a6ea98dccaf041da398508d0c8e Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Wed, 21 Oct 2020 16:40:08 -0700 Subject: [PATCH 4/5] Fix lint --- python/tvm/relay/op/_transform.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index a34c5dbf25a9..2946a5a4ab55 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -778,6 +778,9 @@ def repeat_shape_func(attrs, inputs, _): @_reg.register_shape_func("broadcast_to_like", False) def broadcast_to_like_shape_func(attrs, inputs, _): + """ + Shape func for broadcast_to_like. + """ return [topi.math.identity(inputs[1])] @@ -798,6 +801,9 @@ def _stack_shape_func(data_shape, axis, num_inputs): @_reg.register_shape_func("stack", False) def stack_shape_func(attrs, inputs, _): + """ + Shape func for stack. + """ axis = get_const_int(attrs.axis) if axis < 0: axis += inputs[0].shape[0] + 1 @@ -806,6 +812,9 @@ def stack_shape_func(attrs, inputs, _): @_reg.register_shape_func("where", False) def where_shape_func(attrs, inputs, _): + """ + Shape func for where. + """ cond_shape = inputs[0] x_shape = inputs[1] From dd878b3d5bd1e05cf12707cc0ce20a0517c0abe0 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Thu, 22 Oct 2020 10:26:59 -0700 Subject: [PATCH 5/5] Minor fix --- python/tvm/relay/op/_transform.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 2946a5a4ab55..415529fdcb9a 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -817,10 +817,6 @@ def where_shape_func(attrs, inputs, _): """ cond_shape = inputs[0] x_shape = inputs[1] - - if len(x_shape.shape) == 0: - out_shape = cond_shape - else: - out_shape = x_shape + out_shape = x_shape if x_shape.shape else cond_shape return [topi.math.identity(out_shape)]