42 changes: 29 additions & 13 deletions include/tvm/topi/nn/pooling.h
@@ -75,8 +75,8 @@ inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
   auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]);
   auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]);
 
-  auto height = x->shape[height_axis];
-  auto width = x->shape[width_axis];
+  auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]);
+  auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]);
 
   auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]);
   auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]);
@@ -107,6 +107,9 @@ inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
   auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));
 
   Array<PrimExpr> out_shape = x->shape;
+  for (size_t i = 0; i < out_shape.size(); ++i) {
+    out_shape.Set(i, cast(DataType::DataType::Int(32), out_shape[i]));
+  }
   out_shape.Set(height_axis, out_height);
   out_shape.Set(width_axis, out_width);
 
@@ -189,8 +192,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
   auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]);
   auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]);
 
-  auto height = x->shape[height_axis];
-  auto width = x->shape[width_axis];
+  auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]);
+  auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]);
 
   auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]);
   auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]);
@@ -220,7 +223,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
   auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
   auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));
 
-  Array<PrimExpr> out_shape = x->shape;
+  Array<PrimExpr> data_shape = x->shape;
+  for (size_t i = 0; i < data_shape.size(); ++i) {
+    data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
+  }
+
+  Array<PrimExpr> out_shape = data_shape;
   out_shape.Set(height_axis, out_height);
   out_shape.Set(width_axis, out_width);
 
@@ -232,7 +240,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
                 ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
 
   if (pool_type == kMaxPool) {
-    Array<PrimExpr> ravel_shape{x->shape.begin(), x->shape.end()};
+    Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
     ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
     ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
 
@@ -257,7 +265,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
     auto mp_inds = mp_argmax[0];
 
     return tvm::te::compute(
-        x->shape,
+        data_shape,
         [&](const Array<Var>& inds) {
           Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
           pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
@@ -288,7 +296,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
     auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
     return tvm::te::compute(
-        x->shape,
+        data_shape,
         [&](const Array<Var>& inds) {
           PrimExpr pad_h_idx = inds[height_axis] + pad_top;
           PrimExpr pad_w_idx = inds[width_axis] + pad_left;
@@ -483,10 +491,14 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
   const auto n_dim = output_size.size();
   CHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension";
 
-  Array<PrimExpr> out_shape = x->shape;
+  Array<PrimExpr> data_shape = x->shape;
+  for (size_t i = 0; i < data_shape.size(); ++i) {
+    data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
+  }
+  Array<PrimExpr> out_shape = data_shape;
   Array<PrimExpr> in_size, out_size;
   for (size_t i = 0; i < n_dim; ++i) {
-    in_size.push_back(x->shape[axes[i]]);
+    in_size.push_back(data_shape[axes[i]]);
     out_size.push_back(cast(DataType::Int(32), output_size[i]));
     out_shape.Set(axes[i], out_size[i]);
   }
@@ -661,7 +673,11 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
   std::vector<PrimExpr> pad_tail(k_size);
   Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
   Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
-  Array<PrimExpr> out_shape = x->shape;
+  Array<PrimExpr> data_shape = x->shape;
+  for (size_t i = 0; i < data_shape.size(); ++i) {
+    data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
+  }
+  Array<PrimExpr> out_shape = data_shape;
 
   bool do_pad = false;
   for (int i = 0; i < k_size; i++) {
@@ -687,7 +703,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
 
     arith::Analyzer analyzer;
     auto out_dim = analyzer.Simplify(
-        indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);
+        indexdiv(data_shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);
 
     out_shape.Set(ii, out_dim);
   }
Expand Down Expand Up @@ -746,7 +762,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
for (int i = 0; i < k_size; i++) {
int ii = axis[i];
start[i] = output[ii] * stride[i] - pad_head[i];
end[i] = min(start[i] + kernel[i], x->shape[ii]);
end[i] = min(start[i] + kernel[i], data_shape[ii]);
start[i] = max(start[i], make_const(DataType::Int(32), 0));
kernel_size *= (end[i] - start[i]);
}
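Taken together, the pooling.h changes normalize every input-shape dimension to int32 before it is combined with the int32 kernel, stride, and padding constants, so the generated TIR expressions stay well typed when a frontend hands Relay int64 shape dimensions. A minimal repro sketch of the scenario being fixed (a hypothetical driver script, not part of this PR; it uses the same pre-0.8 ctx/asnumpy executor API as the tests below):

import numpy as np
import tvm
from tvm import relay

# Shape dims carried as int64 IntImms, as an importer such as the ONNX
# frontend can produce. Before this patch, pooling shape arithmetic mixed
# these int64 dims with int32 constants and could fail dtype checking.
dshape = [tvm.tir.IntImm("int64", d) for d in (1, 3, 32, 32)]
x = relay.var("x", shape=dshape, dtype="float32")
y = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2))
func = relay.Function([x], y)

data = np.random.rand(1, 3, 32, 32).astype("float32")
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
print(intrp.evaluate(func)(data).asnumpy().shape)  # expect (1, 3, 16, 16)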
75 changes: 44 additions & 31 deletions tests/python/relay/test_op_grad_level2.py
@@ -66,39 +66,43 @@ def test_max_pool2d_grad():
     )
 
 
-def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad):
-    x = relay.var("x", relay.TensorType(x_shape, "float32"))
-    y = tvm.relay.nn.avg_pool2d(
-        x,
-        pool_size=pool_size,
-        strides=strides,
-        padding=padding,
-        ceil_mode=ceil_mode,
-        count_include_pad=count_include_pad,
-    )
-
-    fwd_func = relay.Function([x], y)
-    fwd_func = run_infer_type(fwd_func)
-    bwd_func = run_infer_type(gradient(fwd_func))
+def verify_avg_pool2d_grad(
+    x_shape, pool_size, strides, padding, ceil_mode, count_include_pad, dtype="float32"
+):
+
+    for shape_dtype in ["int32", "int64"]:
+        x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in x_shape], dtype=dtype)
+        y = tvm.relay.nn.avg_pool2d(
+            x,
+            pool_size=pool_size,
+            strides=strides,
+            padding=padding,
+            ceil_mode=ceil_mode,
+            count_include_pad=count_include_pad,
+        )
 
-    data = np.random.rand(*x_shape).astype("float32")
-    ph, pw = padding
-    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
-    out_grad = np.ones(shape=y_shape)
-    ref_grad = tvm.topi.testing.pool_grad_nchw(
-        data,
-        out_grad,
-        pool_size=pool_size,
-        strides=strides,
-        padding=[ph, pw, ph, pw],
-        pool_type="avg",
-        ceil_mode=ceil_mode,
-    )
+        fwd_func = relay.Function([x], y)
+        fwd_func = run_infer_type(fwd_func)
+        bwd_func = run_infer_type(gradient(fwd_func))
+
+        data = np.random.rand(*x_shape).astype(dtype)
+        ph, pw = padding
+        y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
+        out_grad = np.ones(shape=y_shape)
+        ref_grad = tvm.topi.testing.pool_grad_nchw(
+            data,
+            out_grad,
+            pool_size=pool_size,
+            strides=strides,
+            padding=[ph, pw, ph, pw],
+            pool_type="avg",
+            ceil_mode=ceil_mode,
+        )
 
-    for target, ctx in tvm.testing.enabled_targets():
-        intrp = relay.create_executor(ctx=ctx, target=target)
-        op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
-        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
+        for target, ctx in tvm.testing.enabled_targets():
+            intrp = relay.create_executor(ctx=ctx, target=target)
+            op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
+            np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
 
 
 @tvm.testing.uses_gpu
@@ -119,6 +123,15 @@ def test_avg_pool2d_grad():
         ceil_mode=False,
         count_include_pad=False,
     )
+    verify_avg_pool2d_grad(
+        (1, 4, 16, 16),
+        pool_size=(1, 1),
+        strides=(1, 1),
+        padding=(1, 1),
+        ceil_mode=False,
+        count_include_pad=False,
+        dtype="int32",
+    )
 
 
 def verify_global_avg_pool2d_grad(x_shape):
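The dtype="int32" case added above pairs with the shape_dtype loop: shapes are exercised as both int32 and int64 IntImms while the data itself stays integer. As a sanity check on the expected numbers, a small worked sketch of the added configuration (reasoning only, not test code from this PR):

import numpy as np

# pool_size=(1, 1), strides=(1, 1), padding=(1, 1), count_include_pad=False:
# each 1x1 window holds at most one real input cell, and each input cell is
# covered by exactly one window, so it receives out_grad / 1 exactly once.
# The gradient of an all-ones out_grad is therefore an all-ones input grad.
x_shape = (1, 4, 16, 16)
out_h = (16 + 1 + 1 - 1) // 1 + 1       # padded output height = 18
expected_grad = np.ones(x_shape)
print(out_h, expected_grad.shape)       # 18 (1, 4, 16, 16)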
22 changes: 13 additions & 9 deletions tests/python/relay/test_op_level10.py
@@ -425,17 +425,18 @@ def verify_ndarray_size(shape):
 
 
 def verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc):
-    x = relay.var("x", relay.TensorType(dshape, "float32"))
-    y = opfunc(x, out_size, layout)
-    func = relay.Function([x], y)
+    for shape_dtype in ["int32", "int64"]:
+        x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
+        y = opfunc(x, out_size, layout)
+        func = relay.Function([x], y)
 
-    np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
-    np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout)
+        np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
+        np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout)
 
-    for target, ctx in tvm.testing.enabled_targets():
-        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-        relay_out = intrp1.evaluate(func)(np_data)
-        tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5)
+        for target, ctx in tvm.testing.enabled_targets():
+            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+            relay_out = intrp1.evaluate(func)(np_data)
+            tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5)
 
 
 def verify_adaptive_pool2d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
Expand All @@ -452,13 +453,16 @@ def verify_adaptive_pool3d(dshape, out_size, pool_type, layout="NCHW", dtype="fl
def test_adaptive_pool():
verify_adaptive_pool2d((1, 9, 224, 224), (1, 1), "max")
verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg")
verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg", dtype="int32")
verify_adaptive_pool2d((1, 14, 56, 78), (34, 13), "max")
verify_adaptive_pool2d((1, 5, 46, 97), (4, 96), "avg")
verify_adaptive_pool2d((1, 224, 224, 3), (1, 1), "max", layout="NHWC")
verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg", layout="NHWC")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "max", layout="NCDHW")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NCDHW")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NDHWC")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NCDHW", dtype="int32")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NDHWC", dtype="int32")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (2, 4, 4), "max", layout="NDHWC")


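For context on the adaptive_pool_impl change in pooling.h: adaptive pooling derives per-output-index window bounds from the input extent, so in_size and out_size must share an integer dtype. A minimal NumPy sketch of the 1-D average case (an illustration assuming the floor/ceil start-end convention that TVM's reference implementation uses; adaptive_avg_pool1d is a hypothetical helper name):

import numpy as np

def adaptive_avg_pool1d(x, out_size):
    # Window i spans [floor(i * n / out), ceil((i + 1) * n / out)) on the
    # input axis; the bound math is pure integer arithmetic, which is why
    # the C++ side casts shape dims to int32 before forming it.
    n = x.shape[-1]
    out = np.empty(x.shape[:-1] + (out_size,), dtype="float64")
    for i in range(out_size):
        start = (i * n) // out_size
        end = ((i + 1) * n + out_size - 1) // out_size
        out[..., i] = x[..., start:end].mean(axis=-1)
    return out

print(adaptive_avg_pool1d(np.arange(10, dtype="float64"), 4))  # [1. 3. 6. 8.]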
87 changes: 47 additions & 40 deletions tests/python/relay/test_op_level2.py
@@ -959,15 +959,16 @@ def _test_pool2d_int(opfunc, reffunc, dtype):
     # test execution
     dtype = "int32"
     dshape = (1, 3, 28, 28)
-    x = relay.var("x", shape=dshape, dtype=dtype)
-    y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
-    func = relay.Function([x], y)
-    data = np.random.randint(low=-128, high=128, size=dshape)
-    ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype)
-    for target, ctx in tvm.testing.enabled_targets():
-        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-        op_res1 = intrp1.evaluate(func)(data)
-        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+    for shape_dtype in ["int32", "int64"]:
+        x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
+        y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        func = relay.Function([x], y)
+        data = np.random.randint(low=-128, high=128, size=dshape)
+        ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype)
+        for target, ctx in tvm.testing.enabled_targets():
+            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res1 = intrp1.evaluate(func)(data)
+            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
 
 def _test_global_pool2d(opfunc, reffunc):
@@ -1010,32 +1011,34 @@ def test_pool2d():
 
 @tvm.testing.uses_gpu
 def test_pool1d():
-    def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0)):
+    def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0), dtype="float32"):
         n, c, w = te.var("n"), 10, 224
         x = relay.var("x", relay.TensorType((n, c, w), "float32"))
         y = opfunc(x, pool_size=(1,))
         assert "pool_size=" in y.astext()
         yy = run_infer_type(y)
         assert yy.checked_type == relay.TensorType((n, 10, 224), "float32")
         # test execution
-        dtype = "float32"
         dshape = (1, 3, 32)
-        x = relay.var("x", shape=dshape)
-        pool_type = "max" if "max" in str(opfunc) else "avg"
-        y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
-        func = relay.Function([x], y)
-        data = np.random.uniform(size=dshape).astype(dtype)
-        ref_res = tvm.topi.testing.pool1d_ncw_python(
-            data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False
-        )
-        for target, ctx in tvm.testing.enabled_targets():
-            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-            op_res1 = intrp1.evaluate(func)(data)
-            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+        for shape_dtype in ["int32", "int64"]:
+            x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
+            pool_type = "max" if "max" in str(opfunc) else "avg"
+            y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
+            func = relay.Function([x], y)
+            data = np.random.uniform(size=dshape).astype(dtype)
+            ref_res = tvm.topi.testing.pool1d_ncw_python(
+                data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False
+            )
+            for target, ctx in tvm.testing.enabled_targets():
+                intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+                op_res1 = intrp1.evaluate(func)(data)
+                tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
     _test_pool1d(relay.nn.max_pool1d)
+    _test_pool1d(relay.nn.max_pool1d, dtype="int32")
     _test_pool1d(relay.nn.max_pool1d, pool_size=2, strides=2, padding=0)
     _test_pool1d(relay.nn.avg_pool1d)
+    _test_pool1d(relay.nn.avg_pool1d, dtype="int32")
     _test_pool1d(relay.nn.avg_pool1d, pool_size=2, strides=2, padding=0)
 
 
@@ -1047,6 +1050,7 @@ def _test_pool3d(
         strides=(2, 2, 2),
         padding=(0, 0, 0, 0, 0, 0),
         out_shape=(1, 3, 16, 16, 16),
+        dtype="float32",
     ):
         n, c, d, h, w = te.size_var("n"), 10, 5, 224, 224
         x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
@@ -1057,30 +1061,33 @@ def _test_pool3d(
         # test execution
-        dtype = "float32"
         dshape = (1, 3, 32, 32, 32)
-        x = relay.var("x", shape=dshape)
-        pool_type = "max" if "max" in str(opfunc) else "avg"
-        y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
-        func = relay.Function([x], y)
-        # check output shape
-        f_out_shape = tuple(map(lambda x: int(x), run_infer_type(func).ret_type.shape))
-        assert out_shape == f_out_shape, "Output shape mismatch. expected {}, actual {}".format(
-            out_shape, f_out_shape
-        )
-        data = np.random.uniform(size=dshape).astype(dtype)
-        ref_res = tvm.topi.testing.pool3d_ncdhw_python(
-            data, pool_size, strides, padding, out_shape, pool_type, False
-        )
-        for target, ctx in tvm.testing.enabled_targets():
-            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-            op_res1 = intrp1.evaluate(func)(data)
-            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+        for shape_dtype in ["int32", "int64"]:
+            x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
+            pool_type = "max" if "max" in str(opfunc) else "avg"
+            y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
+            func = relay.Function([x], y)
+            # check output shape
+            f_out_shape = tuple(map(lambda x: int(x), run_infer_type(func).ret_type.shape))
+            assert out_shape == f_out_shape, "Output shape mismatch. expected {}, actual {}".format(
+                out_shape, f_out_shape
+            )
+            data = np.random.uniform(size=dshape).astype(dtype)
+            ref_res = tvm.topi.testing.pool3d_ncdhw_python(
+                data, pool_size, strides, padding, out_shape, pool_type, False
+            )
+            for target, ctx in tvm.testing.enabled_targets():
+                intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+                op_res1 = intrp1.evaluate(func)(data)
+                tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
     _test_pool3d(relay.nn.max_pool3d)
+    _test_pool3d(relay.nn.max_pool3d, dtype="int32")
     _test_pool3d(relay.nn.max_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16))
     _test_pool3d(relay.nn.max_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16))
     _test_pool3d(relay.nn.max_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20))
     _test_pool3d(relay.nn.max_pool3d, pool_size=2, padding=0, strides=2)
     _test_pool3d(relay.nn.avg_pool3d)
+    _test_pool3d(relay.nn.avg_pool3d, dtype="int32")
     _test_pool3d(relay.nn.avg_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16))
     _test_pool3d(relay.nn.avg_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16))
     _test_pool3d(relay.nn.avg_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20))
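A note on the _test_pool2d_int reference above: with kernel equal to stride and zero padding, 2x2 pooling reduces disjoint blocks, so the NumPy reshape trick reproduces the op exactly. A small sketch (illustrative only):

import numpy as np

# (N, C, H, W) -> (N, C, H//2, 2, W//2, 2): axes 3 and 5 enumerate the
# elements of each 2x2 window, so reducing over them equals pooling with
# kernel = stride = 2 and no padding.
data = np.random.randint(low=-128, high=128, size=(1, 3, 28, 28))
ref_max = np.max(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5))   # max_pool2d
ref_avg = np.mean(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5))  # avg_pool2d
print(ref_max.shape, ref_avg.shape)  # (1, 3, 14, 14) (1, 3, 14, 14)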