From 9d751ec76e3e8628e62f3e121e9ec52ae1870167 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Thu, 22 May 2025 21:37:49 +0800 Subject: [PATCH 1/4] Update transform.py --- python/tvm/topi/transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 1ef65230591b..928cb23621cb 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -96,7 +96,8 @@ def _compute(*idxs): axis_index = 0 for i in range(0, len(idxs)): if i not in real_axis: - indices.append(idxs[i]) + dim = tvm.tir.if_then_else(idxs[i] < a.shape[len(indices)], idxs[i], a.shape[len(indices)] - 1) + indices.append(dim) axis_index += 1 return a(*indices) From 38c552fec4adbfa05f99ed4bf81de62967efad81 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Fri, 23 May 2025 09:08:04 +0800 Subject: [PATCH 2/4] Update transform.py --- python/tvm/topi/transform.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 928cb23621cb..faf5315ff4da 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -96,7 +96,9 @@ def _compute(*idxs): axis_index = 0 for i in range(0, len(idxs)): if i not in real_axis: - dim = tvm.tir.if_then_else(idxs[i] < a.shape[len(indices)], idxs[i], a.shape[len(indices)] - 1) + dim = tvm.tir.if_then_else( + idxs[i] < a.shape[len(indices)], idxs[i], a.shape[len(indices)] - 1 + ) indices.append(dim) axis_index += 1 return a(*indices) From 16ca0eeddc0b52b07b9b2738edbd571b6c56bcad Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Fri, 23 May 2025 11:47:01 +0800 Subject: [PATCH 3/4] Update transform.py --- python/tvm/topi/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index faf5315ff4da..c45fcd456d1b 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -97,7 +97,7 @@ def _compute(*idxs): for i in range(0, len(idxs)): if i not in real_axis: dim = tvm.tir.if_then_else( - idxs[i] < a.shape[len(indices)], idxs[i], a.shape[len(indices)] - 1 + a.shape[len(indices)] !=1, idxs[i], 0 ) indices.append(dim) axis_index += 1 From 2bc069979a00d6bb47b11ad1c6bb4ebec7b420b8 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Fri, 23 May 2025 13:37:50 +0800 Subject: [PATCH 4/4] Update transform.py --- python/tvm/topi/transform.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index c45fcd456d1b..951944e618ab 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -96,9 +96,7 @@ def _compute(*idxs): axis_index = 0 for i in range(0, len(idxs)): if i not in real_axis: - dim = tvm.tir.if_then_else( - a.shape[len(indices)] !=1, idxs[i], 0 - ) + dim = tvm.tir.if_then_else(a.shape[len(indices)] != 1, idxs[i], 0) indices.append(dim) axis_index += 1 return a(*indices)