From e691eedecd1e12755be536a543895b21c84696ad Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 13 Aug 2020 13:37:47 -0700 Subject: [PATCH 1/2] Update precision in the ONNX strided_slice, update precision of ToScalar --- python/tvm/relay/frontend/onnx.py | 8 ++++---- src/relay/transforms/pattern_util.h | 6 +++--- tests/python/frontend/onnx/test_forward.py | 11 ++++++----- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 74626d49a9dd..f54a145882a9 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1045,8 +1045,8 @@ def _impl_v1(cls, inputs, attr, params): end = list(attr['ends']) return _op.strided_slice(inputs[0], - begin=_expr.const(begin, dtype="int32"), - end=_expr.const(end, dtype="int32")) + begin=_expr.const(begin, dtype="int64"), + end=_expr.const(end, dtype="int64")) @classmethod def _impl_v10(cls, inputs, attr, params): @@ -1063,8 +1063,8 @@ def _impl_v10(cls, inputs, attr, params): starts = new_starts ends = new_ends return _op.strided_slice(inputs[0], - begin=_expr.const(starts, dtype="int32"), - end=_expr.const(ends, dtype="int32")) + begin=_expr.const(starts, dtype="int64"), + end=_expr.const(ends, dtype="int64")) class Gather(OnnxOpConverter): diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index ee655037bda0..003e02a195cc 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -374,7 +374,7 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { * \param i element index * \return Converted scalar value. */ -static inline double ToScalar(const runtime::NDArray& array, size_t i = 0) { +static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) { if (array->dtype.code == kDLInt) { if (array->dtype.bits == 8) { return reinterpret_cast(array->data)[i]; @@ -423,8 +423,8 @@ static inline Array ToVector(const runtime::NDArray& array) { size_t len = array.Shape().front(); Array out; for (size_t i = 0; i < len; ++i) { - double elem_val = ToScalar(array, i); - out.push_back(Integer(static_cast(elem_val))); + long double elem_val = ToScalar(array, i); + out.push_back(Integer(IntImm(DataType::Int(64), static_cast(elem_val)))); } return out; } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 14b827c79248..c376c9aa78ea 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -478,15 +478,15 @@ def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None): inputs = [ helper.make_tensor_value_info("data", TensorProto.FLOAT, list(indata.shape)), - helper.make_tensor_value_info("starts", TensorProto.INT32, + helper.make_tensor_value_info("starts", TensorProto.INT64, list(starts.shape)), - helper.make_tensor_value_info("ends", TensorProto.INT32, + helper.make_tensor_value_info("ends", TensorProto.INT64, list(ends.shape)) ] initializer = [ - helper.make_tensor("starts", TensorProto.INT32, list(starts.shape), + helper.make_tensor("starts", TensorProto.INT64, list(starts.shape), starts), - helper.make_tensor("ends", TensorProto.INT32, list(ends.shape), ends) + helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends) ] if axes: @@ -534,7 +534,8 @@ def test_slice(): _test_slice_iteration_v10(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1)) _test_slice_iteration_v10(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4)) _test_slice_iteration_v10(x, x[:, 1:1000], (1), (1000), (1)) - _test_slice_iteration_v10(x, x[:, 0:-1], (0), (-1), (1)) + x = np.random.randn(1, 1, 1, 128).astype(np.float32) + _test_slice_iteration_v10(x, x, (0, 0), (9223372036854775807, 9223372036854775807), (0, 3)) def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs): From 481c47fc814eae7df03e866654b0f4d7afc786db Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 13 Aug 2020 15:42:52 -0700 Subject: [PATCH 2/2] fix tests --- src/relay/transforms/pattern_util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 003e02a195cc..9f034848c3d7 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -424,7 +424,7 @@ static inline Array ToVector(const runtime::NDArray& array) { Array out; for (size_t i = 0; i < len; ++i) { long double elem_val = ToScalar(array, i); - out.push_back(Integer(IntImm(DataType::Int(64), static_cast(elem_val)))); + out.push_back(Integer(IntImm(DataType::Int(32), static_cast(elem_val)))); } return out; }