From b8d711ed3bb7877630d1dd2b9f47ea78943f36fe Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Mon, 26 May 2025 16:38:16 +0800 Subject: [PATCH 1/3] Update pooling.h --- include/tvm/topi/nn/pooling.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index abe26b6c6727..8e13ae49afdf 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -383,7 +383,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ PrimExpr divide_factor = tvm::cast(x->dtype, 1); for (size_t i = 0; i < n_dim; ++i) { - divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent); + divide_factor *= tvm::cast(DataType::Int(32), reduce_axes[i]->dom->extent); } return div(pool_sum(indices), divide_factor); From bc0d883bb0616cfdc3dfda473768649a0bec9b0f Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Mon, 26 May 2025 16:53:56 +0800 Subject: [PATCH 2/3] Update test_te_create_primfunc.py --- tests/python/te/test_te_create_primfunc.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index 9925f54be4db..537dd2557c2f 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -882,6 +882,14 @@ def te_workload(): _check_workload(te_workload, tir_workload) +def test_global_pool(): + # fix the issue-17938 + data = te.placeholder((1, 1, 32, 32), dtype='int8', name='data') + op_output = topi.nn.global_pool(data=data, pool_type='avg', layout='NCHW') + f = te.create_prim_func([data, op_output]) + assert f + + def test_nested_reduce_domain_dependency(): @T.prim_func def tir_workload( From 03e2dd1c95bc62a395d4eb698d262f5eeaebf5b8 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Mon, 26 May 2025 18:15:05 +0800 Subject: [PATCH 3/3] fix lint error --- tests/python/te/test_te_create_primfunc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index 537dd2557c2f..b0850a89b5c5 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -884,8 +884,8 @@ def te_workload(): def test_global_pool(): # fix the issue-17938 - data = te.placeholder((1, 1, 32, 32), dtype='int8', name='data') - op_output = topi.nn.global_pool(data=data, pool_type='avg', layout='NCHW') + data = te.placeholder((1, 1, 32, 32), dtype="int8", name="data") + op_output = topi.nn.global_pool(data=data, pool_type="avg", layout="NCHW") f = te.create_prim_func([data, op_output]) assert f