From 9bffc970418806b5d4065b708f438016f87fd3ef Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 14 Oct 2020 16:11:45 -0600 Subject: [PATCH 1/2] Failing tests for Int32 avg_pooling with Int64 shapes --- tests/python/relay/test_op_grad_level2.py | 75 +++++++++++-------- tests/python/relay/test_op_level10.py | 22 +++--- tests/python/relay/test_op_level2.py | 87 ++++++++++++----------- 3 files changed, 104 insertions(+), 80 deletions(-) diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 34bbf9e60b3a..85332da64221 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -66,39 +66,43 @@ def test_max_pool2d_grad(): ) -def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad): - x = relay.var("x", relay.TensorType(x_shape, "float32")) - y = tvm.relay.nn.avg_pool2d( - x, - pool_size=pool_size, - strides=strides, - padding=padding, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - ) - - fwd_func = relay.Function([x], y) - fwd_func = run_infer_type(fwd_func) - bwd_func = run_infer_type(gradient(fwd_func)) +def verify_avg_pool2d_grad( + x_shape, pool_size, strides, padding, ceil_mode, count_include_pad, dtype="float32" +): + + for shape_dtype in ["int32", "int64"]: + x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in x_shape], dtype=dtype) + y = tvm.relay.nn.avg_pool2d( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) - data = np.random.rand(*x_shape).astype("float32") - ph, pw = padding - y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape) - out_grad = np.ones(shape=y_shape) - ref_grad = tvm.topi.testing.pool_grad_nchw( - data, - out_grad, - pool_size=pool_size, - strides=strides, - padding=[ph, pw, ph, pw], - pool_type="avg", - ceil_mode=ceil_mode, - ) + fwd_func = relay.Function([x], y) + fwd_func = run_infer_type(fwd_func) + bwd_func = run_infer_type(gradient(fwd_func)) + + data = np.random.rand(*x_shape).astype(dtype) + ph, pw = padding + y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape) + out_grad = np.ones(shape=y_shape) + ref_grad = tvm.topi.testing.pool_grad_nchw( + data, + out_grad, + pool_size=pool_size, + strides=strides, + padding=[ph, pw, ph, pw], + pool_type="avg", + ceil_mode=ceil_mode, + ) - for target, ctx in tvm.testing.enabled_targets(): - intrp = relay.create_executor(ctx=ctx, target=target) - op_res, (op_grad,) = intrp.evaluate(bwd_func)(data) - np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) + for target, ctx in tvm.testing.enabled_targets(): + intrp = relay.create_executor(ctx=ctx, target=target) + op_res, (op_grad,) = intrp.evaluate(bwd_func)(data) + np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) @tvm.testing.uses_gpu @@ -119,6 +123,15 @@ def test_avg_pool2d_grad(): ceil_mode=False, count_include_pad=False, ) + verify_avg_pool2d_grad( + (1, 4, 16, 16), + pool_size=(1, 1), + strides=(1, 1), + padding=(1, 1), + ceil_mode=False, + count_include_pad=False, + dtype="int32", + ) def verify_global_avg_pool2d_grad(x_shape): diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index bc565682d932..3ec1a5bb6129 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -425,17 +425,18 @@ def verify_ndarray_size(shape): def verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc): 
- x = relay.var("x", relay.TensorType(dshape, "float32")) - y = opfunc(x, out_size, layout) - func = relay.Function([x], y) + for shape_dtype in ["int32", "int64"]: + x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype) + y = opfunc(x, out_size, layout) + func = relay.Function([x], y) - np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype) - np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout) + np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype) + np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout) - for target, ctx in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - relay_out = intrp1.evaluate(func)(np_data) - tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5) + for target, ctx in tvm.testing.enabled_targets(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + relay_out = intrp1.evaluate(func)(np_data) + tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5) def verify_adaptive_pool2d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"): @@ -452,6 +453,7 @@ def verify_adaptive_pool3d(dshape, out_size, pool_type, layout="NCHW", dtype="fl def test_adaptive_pool(): verify_adaptive_pool2d((1, 9, 224, 224), (1, 1), "max") verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg") + verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg", dtype="int32") verify_adaptive_pool2d((1, 14, 56, 78), (34, 13), "max") verify_adaptive_pool2d((1, 5, 46, 97), (4, 96), "avg") verify_adaptive_pool2d((1, 224, 224, 3), (1, 1), "max", layout="NHWC") @@ -459,6 +461,8 @@ def test_adaptive_pool(): verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "max", layout="NCDHW") verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NCDHW") verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NDHWC") + verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NCDHW", dtype="int32") + verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NDHWC", dtype="int32") verify_adaptive_pool3d((1, 16, 32, 32, 32), (2, 4, 4), "max", layout="NDHWC") diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index c25c2bf48ca7..546ea6019e56 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -959,15 +959,16 @@ def _test_pool2d_int(opfunc, reffunc, dtype): # test execution dtype = "int32" dshape = (1, 3, 28, 28) - x = relay.var("x", shape=dshape, dtype=dtype) - y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) - func = relay.Function([x], y) - data = np.random.randint(low=-128, high=128, size=dshape) - ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype) - for target, ctx in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - op_res1 = intrp1.evaluate(func)(data) - tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + for shape_dtype in ["int32", "int64"]: + x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype) + y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) + func = relay.Function([x], y) + data = np.random.randint(low=-128, high=128, size=dshape) + ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype) + for target, ctx in tvm.testing.enabled_targets(): + intrp1 = 
relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) def _test_global_pool2d(opfunc, reffunc): @@ -1010,7 +1011,7 @@ def test_pool2d(): @tvm.testing.uses_gpu def test_pool1d(): - def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0)): + def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0), dtype="float32"): n, c, w = te.var("n"), 10, 224 x = relay.var("x", relay.TensorType((n, c, w), "float32")) y = opfunc(x, pool_size=(1,)) @@ -1018,24 +1019,26 @@ def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0)): yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, 10, 224), "float32") # test execution - dtype = "float32" dshape = (1, 3, 32) - x = relay.var("x", shape=dshape) - pool_type = "max" if "max" in str(opfunc) else "avg" - y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding) - func = relay.Function([x], y) - data = np.random.uniform(size=dshape).astype(dtype) - ref_res = tvm.topi.testing.pool1d_ncw_python( - data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False - ) - for target, ctx in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - op_res1 = intrp1.evaluate(func)(data) - tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + for shape_dtype in ["int32", "int64"]: + x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype) + pool_type = "max" if "max" in str(opfunc) else "avg" + y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding) + func = relay.Function([x], y) + data = np.random.uniform(size=dshape).astype(dtype) + ref_res = tvm.topi.testing.pool1d_ncw_python( + data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False + ) + for target, ctx in tvm.testing.enabled_targets(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) _test_pool1d(relay.nn.max_pool1d) + _test_pool1d(relay.nn.max_pool1d, dtype="int32") _test_pool1d(relay.nn.max_pool1d, pool_size=2, strides=2, padding=0) _test_pool1d(relay.nn.avg_pool1d) + _test_pool1d(relay.nn.avg_pool1d, dtype="int32") _test_pool1d(relay.nn.avg_pool1d, pool_size=2, strides=2, padding=0) @@ -1047,6 +1050,7 @@ def _test_pool3d( strides=(2, 2, 2), padding=(0, 0, 0, 0, 0, 0), out_shape=(1, 3, 16, 16, 16), + dtype="float32", ): n, c, d, h, w = te.size_var("n"), 10, 5, 224, 224 x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32")) @@ -1057,30 +1061,33 @@ def _test_pool3d( # test execution dtype = "float32" dshape = (1, 3, 32, 32, 32) - x = relay.var("x", shape=dshape) - pool_type = "max" if "max" in str(opfunc) else "avg" - y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding) - func = relay.Function([x], y) - # check output shape - f_out_shape = tuple(map(lambda x: int(x), run_infer_type(func).ret_type.shape)) - assert out_shape == f_out_shape, "Output shape mismatch. 
expected {}, actual {}".format( - out_shape, f_out_shape - ) - data = np.random.uniform(size=dshape).astype(dtype) - ref_res = tvm.topi.testing.pool3d_ncdhw_python( - data, pool_size, strides, padding, out_shape, pool_type, False - ) - for target, ctx in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - op_res1 = intrp1.evaluate(func)(data) - tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + for shape_dtype in ["int32", "int64"]: + x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype) + pool_type = "max" if "max" in str(opfunc) else "avg" + y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding) + func = relay.Function([x], y) + # check output shape + f_out_shape = tuple(map(lambda x: int(x), run_infer_type(func).ret_type.shape)) + assert out_shape == f_out_shape, "Output shape mismatch. expected {}, actual {}".format( + out_shape, f_out_shape + ) + data = np.random.uniform(size=dshape).astype(dtype) + ref_res = tvm.topi.testing.pool3d_ncdhw_python( + data, pool_size, strides, padding, out_shape, pool_type, False + ) + for target, ctx in tvm.testing.enabled_targets(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) _test_pool3d(relay.nn.max_pool3d) + _test_pool3d(relay.nn.max_pool3d, dtype="int32") _test_pool3d(relay.nn.max_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16)) _test_pool3d(relay.nn.max_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16)) _test_pool3d(relay.nn.max_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20)) _test_pool3d(relay.nn.max_pool3d, pool_size=2, padding=0, strides=2) _test_pool3d(relay.nn.avg_pool3d) + _test_pool3d(relay.nn.avg_pool3d, dtype="int32") _test_pool3d(relay.nn.avg_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16)) _test_pool3d(relay.nn.avg_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16)) _test_pool3d(relay.nn.avg_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20)) From 08749e0f9c5ac3211fd4380cf6639a9de068177a Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 14 Oct 2020 16:12:49 -0600 Subject: [PATCH 2/2] fix pooling implementations --- include/tvm/topi/nn/pooling.h | 42 ++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index 935d399a6604..2396fc25c23f 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -75,8 +75,8 @@ inline Tensor pool_impl(const Tensor& x, const Array& kernel_size, auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]); auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]); - auto height = x->shape[height_axis]; - auto width = x->shape[width_axis]; + auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]); + auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]); auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]); auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]); @@ -107,6 +107,9 @@ inline Tensor pool_impl(const Tensor& x, const Array& kernel_size, auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width)); Array out_shape = x->shape; + for (size_t i = 0; i < out_shape.size(); ++i) { + out_shape.Set(i, cast(DataType::DataType::Int(32), 
out_shape[i])); + } out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); @@ -189,8 +192,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]); auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]); - auto height = x->shape[height_axis]; - auto width = x->shape[width_axis]; + auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]); + auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]); auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]); auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]); @@ -220,7 +223,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, auto dheight = tvm::te::reduce_axis(Range(0, kernel_height)); auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width)); - Array out_shape = x->shape; + Array data_shape = x->shape; + for (size_t i = 0; i < data_shape.size(); ++i) { + data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i])); + } + + Array out_shape = data_shape; out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); @@ -232,7 +240,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); if (pool_type == kMaxPool) { - Array ravel_shape{x->shape.begin(), x->shape.end()}; + Array ravel_shape{data_shape.begin(), data_shape.end()}; ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom); ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right); @@ -257,7 +265,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, auto mp_inds = mp_argmax[0]; return tvm::te::compute( - x->shape, + data_shape, [&](const Array& inds) { Array pad_inds{inds.begin(), inds.end()}; pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top); @@ -288,7 +296,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height)); auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width)); return tvm::te::compute( - x->shape, + data_shape, [&](const Array& inds) { PrimExpr pad_h_idx = inds[height_axis] + pad_top; PrimExpr pad_w_idx = inds[width_axis] + pad_left; @@ -483,10 +491,14 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ const auto n_dim = output_size.size(); CHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension"; - Array out_shape = x->shape; + Array data_shape = x->shape; + for (size_t i = 0; i < data_shape.size(); ++i) { + data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i])); + } + Array out_shape = data_shape; Array in_size, out_size; for (size_t i = 0; i < n_dim; ++i) { - in_size.push_back(x->shape[axes[i]]); + in_size.push_back(data_shape[axes[i]]); out_size.push_back(cast(DataType::Int(32), output_size[i])); out_shape.Set(axes[i], out_size[i]); } @@ -661,7 +673,11 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, std::vector pad_tail(k_size); Array pad_before(std::vector(x_size, 0)); Array pad_after(std::vector(x_size, 0)); - Array out_shape = x->shape; + Array data_shape = x->shape; + for (size_t i = 0; i < data_shape.size(); ++i) { + data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i])); + } + Array out_shape = data_shape; bool do_pad = false; for (int i = 0; i < k_size; 
i++) { @@ -687,7 +703,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, arith::Analyzer analyzer; auto out_dim = analyzer.Simplify( - indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); + indexdiv(data_shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); out_shape.Set(ii, out_dim); } @@ -746,7 +762,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, for (int i = 0; i < k_size; i++) { int ii = axis[i]; start[i] = output[ii] * stride[i] - pad_head[i]; - end[i] = min(start[i] + kernel[i], x->shape[ii]); + end[i] = min(start[i] + kernel[i], data_shape[ii]); start[i] = max(start[i], make_const(DataType::Int(32), 0)); kernel_size *= (end[i] - start[i]); }
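
For reference, a minimal standalone sketch of the failing case the new tests cover: an int32 avg_pool2d whose input shape is carried as int64 dims. It uses only the Relay APIs already shown in the tests above and is illustrative rather than part of the patch; the pooling.h changes address this case by casting the input shape to Int32 inside the TOPI pooling computes.

import numpy as np
import tvm
import tvm.testing
from tvm import relay

# Input shape expressed as int64 IntImms while the tensor data itself is int32.
dshape = (1, 3, 28, 28)
x = relay.var("x", shape=[tvm.tir.IntImm("int64", d) for d in dshape], dtype="int32")
y = relay.nn.avg_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
func = relay.Function([x], y)

data = np.random.randint(low=-128, high=128, size=dshape).astype("int32")
for target, ctx in tvm.testing.enabled_targets():
    intrp = relay.create_executor("graph", ctx=ctx, target=target)
    out = intrp.evaluate(func)(data)
    print(target, out.asnumpy().shape)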