diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc index a1afbeae6ffe..cccf2c505a51 100644 --- a/src/tir/transforms/inject_permuted_layout.cc +++ b/src/tir/transforms/inject_permuted_layout.cc @@ -19,44 +19,55 @@ /*! * \file inject_permuted_layout.cc - * \brief The pass for inject permuted layout. + * \brief The pass injects permuted layout for shared memory buffers to avoid bank conflicts. */ - #include #include #include #include #include +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../runtime/thread_storage_scope.h" #include "../../support/utils.h" -#include "../ir/functor_common.h" #include "ir_utils.h" namespace tvm { namespace tir { -using tir::Block; -using tir::BlockRealize; -using tir::Call; -using tir::For; +using namespace arith; +using namespace runtime; -class PermutedLayoutInjector : public StmtExprMutator { +class PermutedLayoutInjector : private IRMutatorWithAnalyzer { public: - PermutedLayoutInjector() {} + static PrimFunc Transform(PrimFunc func) { + Analyzer analyzer; + + auto new_body = PermutedLayoutInjector(func, &analyzer)(func->body); + auto func_node = func.CopyOnWrite(); + func_node->body = new_body; + return func; + } private: - Array GetNewIndices(PrimExpr s0, PrimExpr s1, int smem_width) { - // index after vectorize(8) - PrimExpr i = s0, j = floordiv(s1, 8), v = floormod(s1, 8); - PrimExpr permuted_j; - // In the following comments, each number represent a 8 * fp16 load - // which is correspond to a index (i, j) in line 50's PrimExpr - // Each 8 number correspond to 32 memory bank (every bank has 32 bit): - // 8 * 8 * 16bit = 32 * 32bit - // And we have 32 banks in total, so all loads in one column share - // same memory bank - if (smem_width % 64 == 0) { - // use 8 * 8 permuted + explicit PermutedLayoutInjector(PrimFunc func, Analyzer* analyzer) + : IRMutatorWithAnalyzer(analyzer) { + buffer_map_.insert(func->buffer_map.begin(), func->buffer_map.end()); + } + + using IRMutatorWithAnalyzer::VisitExpr_; + using IRMutatorWithAnalyzer::VisitStmt_; + + Array PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size) { + ICHECK(permute_); + // Index after vectorizing by 8 + PrimExpr col_idx_outer = floordiv(col_idx, VECTORIZE_FACTOR), + col_idx_inner = floormod(col_idx, VECTORIZE_FACTOR); + PrimExpr new_col_idx_outer; + if (row_size % 64 == 0) { + // Use 8 * 8 permuted layout + // Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read + // Every row below corresponds to 32 banks // 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 // 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 // 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5 @@ -65,10 +76,13 @@ class PermutedLayoutInjector : public StmtExprMutator { // 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2 // 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1 // 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0 - PrimExpr permuted_j_mod_8 = (floormod(j, 8) ^ floormod(i, 8)); - permuted_j = floordiv(j, 8) * 8 + permuted_j_mod_8; + auto row_idx_sub = floormod(row_idx, 8); + new_col_idx_outer = col_idx_outer ^ row_idx_sub; } else { - // use 8 * 4 permuted + ICHECK(row_size % 32 == 0); + // Use 8 * 4 permuted layout + // Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. 
one read + // Every row below corresponds to 16 banks // 0 1 2 3 ==> 0 1 2 3 // 0 1 2 3 ==> 0 1 2 3 // 0 1 2 3 ==> 1 0 3 2 @@ -77,183 +91,204 @@ class PermutedLayoutInjector : public StmtExprMutator { // 0 1 2 3 ==> 2 3 0 1 // 0 1 2 3 ==> 3 2 1 0 // 0 1 2 3 ==> 3 2 1 0 - // in 8 number each line view: + // View with 8 elements per row: // 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3 // 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2 // 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1 // 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0 - permuted_j = floormod(j, 4) ^ floordiv(floormod(i, 8), 2); + auto row_idx_sub = floormod(row_idx, 8); + new_col_idx_outer = col_idx_outer ^ floordiv(row_idx_sub, 2); } - return {s0, permuted_j * 8 + v}; + return {row_idx, analyzer_->Simplify(new_col_idx_outer * 8 + col_idx_inner)}; } - Stmt VisitStmt_(const BlockRealizeNode* _op) final { - BlockRealize br = Downcast(StmtExprMutator::VisitStmt_(_op)); - BlockRealizeNode* op = br.CopyOnWrite(); - if (op->block->annotations.count("permuted_layout") == 0) { - return br; + static bool CheckAnnotation(ObjectRef annotation) { + if (auto* node = annotation.as()) { + // Support string annotation for backward compatibility + return GetRef(node) != ""; + } else if (auto* node = annotation.as()) { + return node->value != 0; + } else { + LOG(FATAL) << "Invalid permuted layout annotation: " << annotation; } - String val = Downcast(op->block->annotations.at("permuted_layout")); - if (val.empty()) return br; - Block blk = op->block; - Stmt body = blk->body; - if (support::StartsWith(val, "g2s")) { - // Case 1. Rewrite global to share.dyn - - // Step 1.1. Handle case when have local stage - // Block with local stage is like - // body { - // SeqStmt { - // seq[0]: local <- global - // seq[1]: shared.dyn <- local - // } - // } - // We only need to rewrite seq[1] - bool have_local_stage = (body.as() != nullptr); - Stmt upper_loop; - if (have_local_stage) { - SeqStmt seq = Downcast(body); - ICHECK(seq->size() == 2); - upper_loop = seq->seq[0]; - body = seq->seq[1]; - } - - // Step 1.2. get inner loop body - std::vector loops; - while (const ForNode* loop = body.as()) { - loops.push_back(loop); - body = loop->body; - } - Optional if_then_else_condition = NullOpt; - const BufferStoreNode* store = body.as(); - if (!store) { - // Case 1.2.1. IfThenElse generated by reverse_compute_inline - // It is always like - // if condition: - // loop_body - // We just extract the inner loop body inside IfThenElseNode - const IfThenElseNode* if_then_else = body.as(); - store = if_then_else->then_case.as(); - ICHECK(!if_then_else->else_case); - if_then_else_condition = if_then_else->condition; - } - ICHECK(store) << body; - - // Step 1.3. Get smem width and refuse to make any difference if invalid - auto smem_width = store->buffer->shape[1].as()->value; - if (smem_width % 32 != 0) { - LOG(WARNING) << "Permuted Layout for " << op->block->name_hint - << " is not supported since its second dimension is not divisible by 32"; - return br; - } - if (smem_width % 64 == 32) { - if (store->buffer->shape[0].as()->value % 2 != 0) { - LOG(WARNING) << "Permuted Layout for " << op->block->name_hint - << " is not supported since its first dimension is not divisible by 2" - << " and second dimension is not divisible by 64"; - return br; - } - } - - // Step 1.4. Set corresponding member variable - if (val.at(4) == 'A') { - smem_width_A_ = smem_width; - } else { - smem_width_B_ = smem_width; - } - - // Step 1.5. 
Rewrite index - PrimExpr s0 = store->indices[0]; - PrimExpr s1 = store->indices[1]; - Array new_indices = GetNewIndices(s0, s1, smem_width); - // Step 1.6. Create new BlockRealize - Stmt new_body = BufferStore(store->buffer, store->value, new_indices); - if (if_then_else_condition) { - // Case 1.6.1. Add back IfThenElse - new_body = IfThenElse(if_then_else_condition.value(), new_body); - } - for (int i = loops.size() - 1; i >= 0; i--) { - const ForNode* loop = loops[i]; - new_body = For(loop->loop_var, loop->min, loop->extent, loop->kind, new_body, - loop->thread_binding, loop->annotations); - } - if (have_local_stage) { - // Case 1.6.1. Add back local stage - new_body = SeqStmt({upper_loop, new_body}); - } - Block new_blk = Block(blk->iter_vars, blk->reads, blk->writes, blk->name_hint, new_body, - blk->init, blk->alloc_buffers, blk->match_buffers, blk->annotations); - BlockRealize new_br = BlockRealize(op->iter_values, op->predicate, new_blk); - return new_br; - } else if (support::StartsWith(val, "s2l")) { - // Case 2. rewrite share.dyn to local - // Step 2.1. Retrieve previous set member variable - int smem_width = val.at(4) == 'A' ? smem_width_A_ : smem_width_B_; - if (smem_width == -1) { - return br; - } - - // Step 2.2. Rewrite index - // Body of shared.dyn to local is always T.evaluate(T.ptx_ldmatrix(args...)) - // Please refer to the load tensor intrinsic - Evaluate eval = Downcast(body); - Call ldmat_call = Downcast(eval->value); - ICHECK(ldmat_call->args.size() == 7); - Array new_ldmat_args; - // Step 2.2.1. Add unchanged args - for (int i = 0; i < 5; i++) { - new_ldmat_args.push_back(ldmat_call->args[i]); - } - // 5th argument is always a T.tvm_access_ptr call - // Please refer to the load tensor intrinsic - Call accptr_call = Downcast(ldmat_call->args[5]); - PrimExpr smem_offset = ldmat_call->args[6]; - - // Step 2.2.2. Create new access ptr call - Array new_accptr_args; - for (int i = 0; i < 5; i++) { - // 2th args of T.tvm_access_ptr call is offset, we set it to 0 and calculate - // total offset in ldmatrix call - new_accptr_args.push_back(i == 2 ? 0 : accptr_call->args[i]); - } - Call new_accptr_call = Call(accptr_call->dtype, accptr_call->op, new_accptr_args); - new_ldmat_args.push_back(new_accptr_call); - - // Step 2.2.3. Calculate new offset - // We convert offset to 2-dimension, reindex it and convert it back - PrimExpr accptr_offset = accptr_call->args[2]; - PrimExpr offset = smem_offset + accptr_offset; - PrimExpr s0 = floordiv(offset, smem_width), s1 = floormod(offset, smem_width); - Array new_indices = GetNewIndices(s0, s1, smem_width); - PrimExpr new_offset = new_indices[0] * smem_width + new_indices[1]; - new_ldmat_args.push_back(new_offset); - // Step 2.2.4. 
Rewrite the rest part - Call new_ldmat_call = Call(ldmat_call->dtype, ldmat_call->op, new_ldmat_args); - Stmt new_body = Evaluate(new_ldmat_call); - Block new_blk = Block(blk->iter_vars, blk->reads, blk->writes, blk->name_hint, new_body, - blk->init, blk->alloc_buffers, blk->match_buffers, blk->annotations); - BlockRealize new_br = BlockRealize(op->iter_values, op->predicate, new_blk); - return new_br; + } + + Stmt VisitStmt_(const BlockNode* op) final { + // Record the mapping from buffer data var to buffer for later lookup + for (auto buffer : op->alloc_buffers) { + buffer_map_.insert({buffer->data, buffer}); + } + for (auto match_buffer : op->match_buffers) { + buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer}); + } + + if (op->annotations.count("permuted_layout") == 0 || + !CheckAnnotation(op->annotations.at("permuted_layout"))) { + return IRMutatorWithAnalyzer::VisitStmt_(op); } - return StmtExprMutator::VisitStmt_(op); + auto prev_permute = permute_; + permute_ = true; + + Block block = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); + + permute_ = prev_permute; + + // Erase the permuted_layout annotation after the pass + auto block_node = block.CopyOnWrite(); + block_node->annotations.erase("permuted_layout"); + return block; } - int smem_width_A_ = -1; - int smem_width_B_ = -1; -}; + int CheckAndGetBufferRowSize(Buffer buffer) { + CHECK(buffer->shape.size() >= 2) + << "The dimension of Buffer \"" << buffer->name << "\" with shape " << buffer->shape + << " should be at least 2"; -PrimFunc InjectPermutedLayout(PrimFunc func) { - auto fptr = func.CopyOnWrite(); - fptr->body = PermutedLayoutInjector()(std::move(fptr->body)); - return func; -} + auto dim = buffer->shape.size(); + auto buffer_row_size = buffer->shape[dim - 1].as()->value; + auto buffer_col_size = buffer->shape[dim - 2].as()->value; + + if (buffer_row_size % 64 != 0) { + CHECK(buffer_row_size % 32 == 0) + << "Permuted Layout for Buffer \"" << buffer->name << "\" with shape " << buffer->shape + << " is not supported since its second dimension is not divisible by 32"; + CHECK(buffer_col_size % 2 == 0) + << "Permuted Layout for Buffer \"" << buffer->name << "\" with shape " << buffer->shape + << " is not supported since its first dimension is not divisible by 2 and second " + "dimension is not divisible by 64"; + } + + return buffer_row_size; + } + + Array HandleBufferIndices(Buffer buffer, Array indices) { + auto buffer_row_size = CheckAndGetBufferRowSize(buffer); + + // Mutate the last two indices + auto indices_size = indices.size(); + PrimExpr row_idx = indices[indices_size - 2]; + PrimExpr col_idx = indices[indices_size - 1]; + auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size); + indices.Set(indices_size - 2, new_indices[0]); + indices.Set(indices_size - 1, new_indices[1]); + return indices; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + // Rewrite write from global to shared.dyn or shared + // We assume the shape of the shared memory is [..., row_size, col_size], + // where row_size is divisible by 64, or divisible by 32 and col_size is divisible by 2. 
+ auto store = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); + + if (!permute_ || store->buffer->shape.size() < 2) { + return store; + } + + auto scope = StorageScope::Create(GetPtrStorageScope(store->buffer->data)); + if (scope.rank != StorageRank::kShared) { + return store; + } + + auto store_node = store.CopyOnWrite(); + store_node->indices = HandleBufferIndices(store_node->buffer, store_node->indices); + return store; + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + // Rewrite load from shared or shared.dyn to global + auto load = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + + if (!permute_ || load->buffer->shape.size() < 2) { + return load; + } + + auto scope = StorageScope::Create(GetPtrStorageScope(load->buffer->data)); + if (scope.rank != StorageRank::kShared) { + return load; + } + + auto load_node = load.CopyOnWrite(); + load_node->indices = HandleBufferIndices(load_node->buffer, load_node->indices); + return load; + } + + PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, Optional offset = NullOpt) { + // The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and accumulate it to + // smem_offset + CHECK(access_ptr->IsInstance()) + << "Invalid access ptr for permuted layout: " << access_ptr; + auto access_ptr_call = Downcast(access_ptr); + CHECK(access_ptr_call->op.same_as(builtin::tvm_access_ptr())) + << "Invalid access ptr for permuted layout: " << access_ptr; + + auto buffer_map_iter = buffer_map_.find(Downcast(access_ptr_call->args[1])); + CHECK(buffer_map_iter != buffer_map_.end()) + << "The buffer corresponding to data Var " << access_ptr_call->args[1] << " is not found"; + int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second); + + PrimExpr smem_offset = access_ptr_call->args[2] + (offset.defined() ? 
offset.value() : 0); + + // Convert offset to 2-dimension, reindex it and convert it back + PrimExpr row_idx = floordiv(smem_offset, buffer_row_size); + PrimExpr col_idx = floormod(smem_offset, buffer_row_size); + + auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size); + auto new_offset = analyzer_->Simplify(new_indices[0] * buffer_row_size + new_indices[1]); + + auto new_access_ptr = access_ptr_call.CopyOnWrite(); + new_access_ptr->args.Set(2, new_offset); + return access_ptr_call; + } + + PrimExpr VisitExpr_(const CallNode* op) final { + // Rewrite from/to shared or shared.dyn to/from local + auto call = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + + if (!permute_) { + return call; + } + + if (!call->op.same_as(builtin::ptx_ldmatrix()) && !call->op.same_as(builtin::mma_store())) { + return call; + } + + if (call->op.same_as(builtin::ptx_ldmatrix())) { + // form: T.ptx_ldmatrix(..., smem_ptr, smem_offset) + // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask) + auto access_ptr = call->args[5]; + PrimExpr smem_offset = call->args[6]; + auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, smem_offset); + auto new_call = call.CopyOnWrite(); + new_call->args.Set(5, new_access_ptr); + new_call->args.Set(6, IntImm(smem_offset->dtype, 0)); + return call; + } else if (call->op.same_as(builtin::mma_store())) { + // TODO(yixin): mma_store is not fully tested yet + // because we will directly store result to Buffer instead of calling mma_store now + auto access_ptr = call->args[2]; + auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr); + auto new_call = call.CopyOnWrite(); + new_call->args.Set(2, new_access_ptr); + return call; + } else { + LOG(FATAL) << "Invalid call node: " << call; + } + } + + static constexpr size_t VECTORIZE_FACTOR = 8; + static constexpr size_t BANK_SIZE_BYTES = 128; + + // Mapping from data Var of a Buffer to Buffer, for lookup + std::unordered_map buffer_map_; + bool permute_ = false; +}; namespace transform { Pass InjectPermutedLayout() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return InjectPermutedLayout(std::move(f)); + return PermutedLayoutInjector::Transform(std::move(f)); }; return CreatePrimFuncPass(pass_func, 0, "tir.InjectPermutedLayout", {}); } diff --git a/tests/python/unittest/test_tir_transform_inject_permuted_layout.py b/tests/python/unittest/test_tir_transform_inject_permuted_layout.py new file mode 100644 index 000000000000..6495cdb2bd54 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_inject_permuted_layout.py @@ -0,0 +1,351 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +import tvm.testing +from tvm import IRModule +from tvm.script import tir as T +from tvm.tir import PrimFunc + + +def _check_primfunc_transform(before: PrimFunc, expected: PrimFunc): + before_module = IRModule.from_expr(before) + after_module = tvm.tir.transform.InjectPermutedLayout()(before_module) + + after = after_module["before"].without_attr("global_symbol") + expected = expected.without_attr("global_symbol") + + tvm.ir.assert_structural_equal(after, expected) + + +# This pass is adapted from another previous pass, so we need to ensure backward compatibility here +def test_backward_compatibility_shared_a(): + # fmt: off + @T.prim_func + def before(X: T.Buffer((4096, 4096), "float16")): + # with T.block("root"): + for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"): + for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"): + with T.block(""): + T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 4072]) + T.writes() + for ax2_0_0 in range(128): + with T.block(""): + T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8]) + T.writes() + X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn") + with T.block("X_reindex_shared.dyn"): + T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8]) + T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8]) + T.block_attr({"permuted_layout": "g2s_A"}) + for ax0_ax1_fused_0 in range(4): + for ax0_ax1_fused_3 in T.vectorized(8): + X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, threadIdx_x % 4 * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3] + for ax2_0_1 in range(4): + with T.block(""): + T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64:threadIdx_y // 2 * 64 + 64, ax2_0_1 * 8:ax2_0_1 * 8 + 8]) + T.writes() + X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA") + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"): + T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8]) + T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8]) + T.block_attr({"permuted_layout": "s2l_A"}) + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + ax2_0_1 * 8, 1024, 1), threadIdx_x * 32) + + @T.prim_func + def expected(X: T.Buffer((4096, 4096), "float16")): + for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"): + for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"): + with T.block(""): + for ax2_0_0 in T.serial(128): + with T.block(""): + 
X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn") + with T.block("X_reindex_shared.dyn"): + # annotate the reads and writes because they cannot be inferred from tir.bitwise_xor + T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8]) + T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8]) + for ax0_ax1_fused_0 in range(4): + for ax0_ax1_fused_3 in T.vectorized(8): + X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, T.bitwise_xor(threadIdx_x % 4, threadIdx_x // 8) * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3] + for ax2_0_1 in T.serial(4): + with T.block(""): + X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA") + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"): + T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8]) + T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8]) + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + threadIdx_x * 32 + T.bitwise_xor(ax2_0_1, threadIdx_x % 8 // 2) * 8, 1024, 1), 0) + # fmt: on + _check_primfunc_transform(before, expected) + + +def test_backward_compatibility_shared_a_and_b(): + # fmt: off + @T.prim_func + def before(X: T.Buffer((4096, 4096), "float16"), Y: T.Buffer((4096, 4096), "float16")): + for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"): + for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"): + for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"): + with T.block(""): + for ax2_0_0 in T.serial(128): + with T.block(""): + X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn") + Y_reindex_shared_dyn = T.alloc_buffer((32, 128), "float16", strides=(128, 1), scope="shared.dyn") + with T.block("X_reindex_shared.dyn"): + T.block_attr({"permuted_layout": "g2s_A"}) + for ax0_ax1_fused_0 in range(4): + for ax0_ax1_fused_3 in T.vectorized(8): + X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, threadIdx_x % 4 * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3] + with T.block("Y_reindex_shared.dyn"): + T.block_attr({"permuted_layout": "g2s_B"}) + for ax0_ax1_fused_0 in range(4): + for ax0_ax1_fused_3 in T.vectorized(8): + Y_reindex_shared_dyn[ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, threadIdx_x % 16 * 8 + ax0_ax1_fused_3] = Y[ax2_0_0 * 32 + ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + ax0_ax1_fused_3] + for ax2_0_1 in T.serial(4): + with T.block(""): + X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA") + Y_reindex_shared_dyn_m16n8k8_matrixB = 
T.alloc_buffer((8, 64), "float16", scope="m16n8k8.matrixB") + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"): + T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8]) + T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8]) + T.block_attr({"permuted_layout": "s2l_A"}) + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + ax2_0_1 * 8, 1024, 1), threadIdx_x * 32) + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("Y_reindex_shared.dyn_m16n8k8.matrixB_o"): + T.reads(Y_reindex_shared_dyn[ax2_0_1 * 8:ax2_0_1 * 8 + 8, threadIdx_y % 2 * 64 + ax1_0 * 32:threadIdx_y % 2 * 64 + ax1_0 * 32 + 32]) + T.writes(Y_reindex_shared_dyn_m16n8k8_matrixB[0:8, ax1_0 * 32:ax1_0 * 32 + 32]) + T.block_attr({"permuted_layout": "s2l_B"}) + T.ptx_ldmatrix("float16", T.bool(True), 4, ".b16", Y_reindex_shared_dyn_m16n8k8_matrixB.data, ax1_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), Y_reindex_shared_dyn.data, ax2_0_1 * 1024 + threadIdx_y % 2 * 64 + ax1_0 * 32, 1024, 1), threadIdx_x % 8 * 128 + threadIdx_x // 8 * 8) + + @T.prim_func + def expected(X: T.Buffer((4096, 4096), "float16"), Y: T.Buffer((4096, 4096), "float16")): + for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"): + for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"): + for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"): + with T.block(""): + T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 4072], Y[threadIdx_y * 2 + threadIdx_x // 16:threadIdx_y * 2 + threadIdx_x // 16 + 4089, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + 8]) + T.writes() + for ax2_0_0 in T.serial(128): + with T.block(""): + T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8], Y[ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16:ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16 + 25, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + 8]) + T.writes() + X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn") + Y_reindex_shared_dyn = T.alloc_buffer((32, 128), "float16", strides=(128, 1), scope="shared.dyn") + with T.block("X_reindex_shared.dyn"): + T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8]) + T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8]) + for ax0_ax1_fused_0 in range(4): + for ax0_ax1_fused_3 in T.vectorized(8): + X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, T.bitwise_xor(threadIdx_x % 4, threadIdx_x // 8) * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + 
threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3] + with T.block("Y_reindex_shared.dyn"): + T.reads(Y[ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16:ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16 + 25, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + 8]) + T.writes(Y_reindex_shared_dyn[threadIdx_y * 2 + threadIdx_x // 16:threadIdx_y * 2 + threadIdx_x // 16 + 25, threadIdx_x % 16 * 8:threadIdx_x % 16 * 8 + 8]) + for ax0_ax1_fused_0 in range(4): + for ax0_ax1_fused_3 in T.vectorized(8): + Y_reindex_shared_dyn[ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, T.bitwise_xor(threadIdx_x % 16, threadIdx_y * 2 + threadIdx_x // 16) * 8 + ax0_ax1_fused_3] = Y[ax2_0_0 * 32 + ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + ax0_ax1_fused_3] + for ax2_0_1 in T.serial(4): + with T.block(""): + X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA") + Y_reindex_shared_dyn_m16n8k8_matrixB = T.alloc_buffer((8, 64), "float16", scope="m16n8k8.matrixB") + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"): + T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8]) + T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8]) + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + threadIdx_x * 32 + T.bitwise_xor(ax2_0_1, threadIdx_x % 8 // 2) * 8, 1024, 1), 0) + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("Y_reindex_shared.dyn_m16n8k8.matrixB_o"): + T.reads(Y_reindex_shared_dyn[ax2_0_1 * 8:ax2_0_1 * 8 + 8, threadIdx_y % 2 * 64 + ax1_0 * 32:threadIdx_y % 2 * 64 + ax1_0 * 32 + 32]) + T.writes(Y_reindex_shared_dyn_m16n8k8_matrixB[0:8, ax1_0 * 32:ax1_0 * 32 + 32]) + T.ptx_ldmatrix("float16", T.bool(True), 4, ".b16", Y_reindex_shared_dyn_m16n8k8_matrixB.data, ax1_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), Y_reindex_shared_dyn.data, ax2_0_1 * 1024 + threadIdx_x % 8 * 128 + T.bitwise_xor(threadIdx_y % 2 * 8 + ax1_0 * 4 + threadIdx_x // 8, threadIdx_x % 8) * 8, 1024, 1), 0) + # fmt: on + _check_primfunc_transform(before, expected) + + +def test_buffer_a(): + # fmt: off + @T.prim_func + def before(p_A: T.handle): + A = T.match_buffer(p_A, (T.int64(128), T.int64(32)), "float16") + A_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", scope="shared.dyn") + A_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), T.int64(8)), "float16", scope="warp") + for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for v0 in range(T.int64(4)): + for v1 in T.vectorized(T.int64(8)): + with T.block("A_reindex_shared.dyn"): + T.block_attr({"permuted_layout": 1}) + A_shared_dyn[ + v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), + threadIdx_x % T.int64(4) * T.int64(8) + v1 + ] = A[ + (v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4)) % T.int64(32), + threadIdx_x % T.int64(4) * T.int64(8) + v1 + ] + 
for v0, v1 in T.grid(T.int64(2), T.int64(4)): + with T.block("A_reindex_shared.dyn_warp_o"): + T.block_attr({"permuted_layout": 1}) + with T.block("A_reindex_shared.dyn_warp_o"): + T.reads(A_shared_dyn[threadIdx_z * T.int64(64) + v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) + T.writes(A_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", + A_warp.data, + v1 * T.int64(256) + threadIdx_x * T.int64(8), + T.tvm_access_ptr(T.type_annotation("float16"), + A_shared_dyn.data, + threadIdx_z * T.int64(2048) + v1 * T.int64(512) + v0 * T.int64(16), T.int64(512), + 1 + ), + threadIdx_x % T.int64(16) * T.int64(32) + threadIdx_x // T.int64(16) * T.int64(8) + ) + + @T.prim_func + def expected(A: T.Buffer((T.int64(128), T.int64(32)), "float16")): + A_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", scope="shared.dyn") + A_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), T.int64(8)), "float16", scope="warp") + for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for v0 in range(T.int64(4)): + for v1 in T.vectorized(T.int64(8)): + with T.block("A_reindex_shared.dyn"): + T.reads(A[(v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4)) % T.int64(32), threadIdx_x % T.int64(4) * T.int64(8) + v1]) + T.writes(A_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1]) + A_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), T.bitwise_xor(threadIdx_x % T.int64(4), threadIdx_x // T.int64(8)) * T.int64(8) + v1] = A[(v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4)) % T.int64(32), threadIdx_x % T.int64(4) * T.int64(8) + v1] + for v0, v1 in T.grid(T.int64(2), T.int64(4)): + with T.block("A_reindex_shared.dyn_warp_o"): + T.reads(A_shared_dyn[threadIdx_z * T.int64(64) + v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) + T.writes(A_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + with T.block("A_reindex_shared.dyn_warp_o"): + T.reads(A_shared_dyn[threadIdx_z * T.int64(64) + v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) + T.writes(A_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", A_warp.data, v1 * T.int64(256) + threadIdx_x * T.int64(8), T.tvm_access_ptr(T.type_annotation("float16"), A_shared_dyn.data, threadIdx_z * T.int64(2048) + v1 * T.int64(512) + threadIdx_x % T.int64(16) * T.int64(32) + T.bitwise_xor(v0 * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(8) // T.int64(2)) * T.int64(8), T.int64(512), 1), T.int64(0)) + + # fmt: on + _check_primfunc_transform(before, expected) + + +def test_buffer_b(): + # fmt: off + @T.prim_func + def before(B: T.Buffer((T.int64(128), T.int64(32)), "float16")): + B_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", scope="shared.dyn") + for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"): + 
for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for v0 in range(T.int64(4)): + for v1 in T.vectorized(T.int64(8)): + with T.block("B_reindex_shared.dyn"): + T.block_attr({"permuted_layout": 1}) + B_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1] = B[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1] + for v0 in range(T.int64(2)): + with T.block(""): + B_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), T.int64(8)), "float16", scope="warp") + for v1 in range(T.int64(4)): + with T.block("B_reindex_shared.dyn_warp_o"): + T.block_attr({"permuted_layout": 1}) + with T.block("B_reindex_shared.dyn_warp_o"): + T.reads(B_shared_dyn[threadIdx_y * T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) + T.writes(B_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", B_warp.data, v1 * T.int64(256) + threadIdx_x * T.int64(8), T.tvm_access_ptr(T.type_annotation("float16"), B_shared_dyn.data, threadIdx_y * T.int64(2048) + v1 * T.int64(512) + v0 * T.int64(16), T.int64(512), 1), threadIdx_x // T.int64(16) * T.int64(256) + threadIdx_x % T.int64(8) * T.int64(32) + threadIdx_x % T.int64(16) // T.int64(8) * T.int64(8)) + + @T.prim_func + def expected(B: T.Buffer((T.int64(128), T.int64(32)), "float16")): + B_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", scope="shared.dyn") + for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for v0 in range(T.int64(4)): + for v1 in T.vectorized(T.int64(8)): + with T.block("B_reindex_shared.dyn"): + T.reads(B[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1]) + T.writes(B_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1]) + B_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), T.bitwise_xor(threadIdx_x % T.int64(4), threadIdx_x // T.int64(8)) * T.int64(8) + v1] = B[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1] + for v0 in range(T.int64(2)): + with T.block(""): + B_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), T.int64(8)), "float16", scope="warp") + for v1 in range(T.int64(4)): + with T.block("B_reindex_shared.dyn_warp_o"): + T.reads(B_shared_dyn[threadIdx_y * T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) + T.writes(B_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + with T.block("B_reindex_shared.dyn_warp_o"): + T.reads(B_shared_dyn[threadIdx_y * T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) + T.writes(B_warp[v1, T.int64(0), T.int64(0):T.int64(32), 
T.int64(0):T.int64(8)]) + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", B_warp.data, v1 * T.int64(256) + threadIdx_x * T.int64(8), T.tvm_access_ptr(T.type_annotation("float16"), B_shared_dyn.data, threadIdx_y * T.int64(2048) + v1 * T.int64(512) + threadIdx_x // T.int64(16) * T.int64(256) + threadIdx_x % T.int64(8) * T.int64(32) + T.bitwise_xor(v0 * T.int64(2) + threadIdx_x % T.int64(16) // T.int64(8), threadIdx_x % T.int64(8) // T.int64(2)) * T.int64(8), T.int64(512), 1), T.int64(0)) + + # fmt: on + _check_primfunc_transform(before, expected) + + +def test_buffer_c_fp32(): + # fmt: off + @T.prim_func + def before(p_O: T.handle): + O = T.match_buffer(p_O, (T.int64(128), T.int64(128)), "float16") + O_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(128)), scope="shared.dyn") + O_warp = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(32), T.int64(8)), scope="warp") + for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for v0, v1 in T.grid(T.int64(4), T.int64(4)): + with T.block("O.dyn_warp_o"): + T.block_attr({"permuted_layout": 1}) + with T.block("O.dyn_warp_o"): + for local_id in range(T.int64(8)): + O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + local_id % T.int64(4) // T.int64(2) * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_y * T.int64(64) + v1 * T.int64(16) + local_id // T.int64(4) * T.int64(8) + threadIdx_x % T.int64(4) * T.int64(2) + local_id % T.int64(2)] = O_warp[v0, v1, threadIdx_x, local_id] + for v0 in range(T.int64(16)): + for v1 in T.vectorized(T.int64(8)): + with T.block("O.dyn"): + T.block_attr({"permuted_layout": 1}) + O[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1] = T.Cast("float16", O_shared_dyn[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1]) + + + @T.prim_func + def expected(O: T.Buffer((T.int64(128), T.int64(128)), "float16")): + # with T.block("root"): + O_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(128)), scope="shared.dyn") + O_warp = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(32), T.int64(8)), scope="warp") + for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for v0, v1 in T.grid(T.int64(4), T.int64(4)): + with T.block("O.dyn_warp_o"): + T.reads(O_warp[v0, v1, threadIdx_x, T.int64(0):T.int64(8)]) + T.writes(O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4):threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4) + T.int64(9), threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2):threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2) + T.int64(10)]) + with T.block("O.dyn_warp_o"): + T.reads(O_warp[v0, v1, threadIdx_x, T.int64(0):T.int64(8)]) + T.writes(O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4):threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4) + T.int64(9), threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2):threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2) + 
T.int64(10)]) + for local_id in range(T.int64(8)): + O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + local_id % T.int64(4) // T.int64(2) * T.int64(8) + threadIdx_x // T.int64(4), T.bitwise_xor(threadIdx_y * T.int64(8) + v1 * T.int64(2) + local_id // T.int64(4), threadIdx_x // T.int64(4)) * T.int64(8) + threadIdx_x % T.int64(4) * T.int64(2) + local_id % T.int64(2)] = O_warp[v0, v1, threadIdx_x, local_id] + for v0 in range(T.int64(16)): + for v1 in T.vectorized(T.int64(8)): + with T.block("O.dyn"): + T.reads(O_shared_dyn[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1]) + T.writes(O[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1]) + O[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1] = T.Cast("float16", O_shared_dyn[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), T.bitwise_xor(threadIdx_x % T.int64(16), threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16)) * T.int64(8) + v1]) + + # fmt: on + _check_primfunc_transform(before, expected) + + +if __name__ == "__main__": + tvm.testing.main()
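
As a rough standalone illustration of the layout this pass injects (not part of the patch itself), the Python sketch below mirrors the arithmetic of `PermuteIndices`: the last buffer dimension is handled in 8-element fp16 chunks (16 bytes, i.e. 4 of the 32-bit shared-memory banks per chunk), and the chunk index is XORed with a function of the row index so that the 8 rows touched by a column-wise `ldmatrix` access land in 8 distinct chunk slots instead of the same one. The helper name `permute` and the small self-check are illustrative assumptions, not code from the patch.

VECTORIZE_FACTOR = 8  # one vectorized access covers 8 consecutive fp16 values

def permute(row_idx, col_idx, row_size):
    # Mirror PermuteIndices: permute at the granularity of 8-element chunks.
    col_outer, col_inner = col_idx // VECTORIZE_FACTOR, col_idx % VECTORIZE_FACTOR
    if row_size % 64 == 0:
        new_col_outer = col_outer ^ (row_idx % 8)          # 8 x 8 permuted layout
    else:
        assert row_size % 32 == 0
        new_col_outer = col_outer ^ ((row_idx % 8) // 2)   # 8 x 4 permuted layout
    return row_idx, new_col_outer * VECTORIZE_FACTOR + col_inner

# Without the permutation, chunk 0 of rows 0..7 hits the same 4 banks 8 times;
# with it, the 8 accesses map to 8 different chunk slots (different banks).
assert len({permute(r, 0, 64)[1] // VECTORIZE_FACTOR for r in range(8)}) == 8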