From 6f6ad82898208fa3c642e1affb5bb46bf77d02f6 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 27 May 2024 12:34:33 +0800 Subject: [PATCH] [TIR] Fix Shuffle rewrite This PR fixes the shuffle rewrite pass to handle the case where the vector lanes are larger than the data type of the input vector. --- src/target/source/codegen_c.cc | 4 +- src/tir/transforms/storage_rewrite.cc | 2 +- ...ir_transform_pointer_value_type_rewrite.py | 46 +++++++++++++++++-- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 009fc1672ace..344d0392d4f6 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -932,7 +932,9 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { // NOLINT( } if (op->indices.size() == 1) { // This is an extract element - os << concat_vec[Downcast(op->indices[0])->value]; + int64_t idx = Downcast(op->indices[0])->value; + ICHECK_LT(idx, concat_vec.size()); + os << concat_vec[idx]; } else { // Print the shuffle as vector constructor // vec(e0, e1, e2, .. en) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 2ebb7671492a..1c3f916a445d 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1493,7 +1493,7 @@ class VectorTypeRewriter : public StmtExprMutator { arith::ModularSet me = analyzer_.modular_set(last_dim_index); ICHECK(me->coeff == 0 || info.factor() % me->coeff == 0); PrimExpr new_index = last_dim_index / make_const(last_dim_index.dtype(), info.factor()); - shuffle_index = me->base; + shuffle_index = me->base % info.factor(); indices.Set(indices.size() - 1, new_index); } diff --git a/tests/python/tir-transform/test_tir_transform_pointer_value_type_rewrite.py b/tests/python/tir-transform/test_tir_transform_pointer_value_type_rewrite.py index 7baa96c1a16e..186f6bd02ae8 100644 --- a/tests/python/tir-transform/test_tir_transform_pointer_value_type_rewrite.py +++ b/tests/python/tir-transform/test_tir_transform_pointer_value_type_rewrite.py @@ -14,10 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name, missing-docstring + import tvm import tvm.testing -from tvm import te -from tvm.driver.build_module import schedule_to_module from tvm.script import tir as T @@ -25,7 +25,7 @@ class BaseCompare(tvm.testing.CompareBeforeAfter): transform = tvm.tir.transform.PointerValueTypeRewrite() -class TestRewriteToShuffle(BaseCompare): +class TestRewriteToShuffle0(BaseCompare): @T.prim_func def before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")): A_local_data = T.allocate([16], "float32", scope="local") @@ -50,6 +50,42 @@ def expected(A: T.Buffer((4,), "float32x4"), B: T.Buffer((4,), "float32")): ) +class TestRewriteToShuffle1(BaseCompare): + @T.prim_func + def before(A: T.Buffer((8,), "float32"), B: T.Buffer((1,), "float32")): + A_local_data = T.allocate([8], "float32", scope="local") + A_local = T.Buffer((8,), "float32", data=A_local_data, scope="local") + A_local[0:4] = A[0:4] + A_local[4:8] = A[4:8] + B[0] = ( + A_local[0] + + A_local[1] + + A_local[2] + + A_local[3] + + A_local[4] + + A_local[5] + + A_local[6] + + A_local[7] + ) + + @T.prim_func + def expected(A: T.Buffer((2,), "float32x4"), B: T.Buffer((1,), "float32")): + A_local_data = T.allocate([2], "float32x4", "local") + A_local = T.Buffer((2,), "float32x4", data=A_local_data, scope="local") + A_local[0] = A[0] + A_local[1] = A[1] + B[0] = ( + T.Shuffle([A_local[0]], [0]) + + T.Shuffle([A_local[0]], [1]) + + T.Shuffle([A_local[0]], [2]) + + T.Shuffle([A_local[0]], [3]) + + T.Shuffle([A_local[1]], [0]) + + T.Shuffle([A_local[1]], [1]) + + T.Shuffle([A_local[1]], [2]) + + T.Shuffle([A_local[1]], [3]) + ) + + class TestAddressOf(BaseCompare): @T.prim_func def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): @@ -71,3 +107,7 @@ def before(A: T.Buffer((16,), "float32")): T.evaluate(A[i * 4]) expected = before + + +if __name__ == "__main__": + tvm.testing.main()