diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py index 0b7072b65afe..45cd6659b6e0 100644 --- a/python/tvm/meta_schedule/testing/space_generation.py +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -88,7 +88,7 @@ def _find_match_sketch_id( decisions=new_decisions, ).apply_to_schedule(sch, remove_postproc=True) if structural_equal(sch.mod, expected_mod): - verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask) + verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask, text_format="json") return sketch_id return None diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 99ffc1bfcdf7..6f9b46a0f734 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -162,6 +162,7 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); pass_list.push_back(tir::transform::UnifyThreadBinding()); + pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 779114e9cfea..0312c100b51b 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -186,15 +186,15 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { return results; } -Array MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, - int n_tiles) const { +std::pair, Array> MultiLevelTilingNode::SplitLoop( + const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const { Array factors = sch->SamplePerfectTile( /*loop=*/loop, /*n=*/n_tiles, /*max_innermost_factor=*/max_innermost_factor); Array splits = sch->Split(/*loop=*/loop, /*factors=*/{factors.begin(), factors.end()}); - return splits; + return {factors, splits}; } std::vector MultiLevelTilingNode::TileLoopNest(State state) const { @@ -207,6 +207,9 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { // Step 2. For each loop axis, tile it int64_t spatial_loop_product = 1; std::vector> tiles(s_indices_.size() + r_indices_.size()); + state->tile_factors.resize(tiles.size()); + std::vector> tile_factors; + tile_factors.resize(tiles.size()); for (int i = 0, n = loops.size(); i < n; ++i) { LoopRV loop = loops[i]; const std::vector* idx = nullptr; @@ -231,14 +234,16 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { if (n_tiles == 1) { tiles[idx->at(0)].push_back(loop); } else { - auto splits = SplitLoop(sch, block_rv, loop, n_tiles); + auto [factors, splits] = SplitLoop(sch, block_rv, loop, n_tiles); // Put every tile to its slot for (int j = 0; j < n_tiles; ++j) { tiles[idx->at(j)].push_back(splits[j]); + tile_factors[idx->at(j)].push_back(factors[j]); } } } + state->tile_factors = std::move(tile_factors); // Step 3. Reorder to organize the tiles sch->Reorder(support::ConcatArrayList(tiles.begin(), tiles.end())); // Step 4. 
Bind the tiles to threads diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index ff38756ff06b..41b3ca9f26f3 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -94,6 +94,8 @@ class StateNode : public Object { tir::BlockRV block_rv; /*! \brief The loop tiles */ Array> tiles; + /*! \brief The factors of the loop tiles. */ + Array> tile_factors; /*! \brief The mapping from buffer index to read cache block. */ std::unordered_map read_reuse; /*! \brief The mapping from buffer index to write cache block. */ @@ -163,8 +165,10 @@ class MultiLevelTilingNode : public ScheduleRuleNode { protected: virtual std::vector ApplySubRules(std::vector states); - virtual Array SplitLoop(const tir::Schedule& sch, tir::BlockRV block, - tir::LoopRV loop, int n_tiles) const; + virtual std::pair, Array> SplitLoop(const tir::Schedule& sch, + tir::BlockRV block, + tir::LoopRV loop, + int n_tiles) const; // Annotate a block to use cooperative fetching void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index d5cca52d41f9..1f9945022b66 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -124,6 +125,9 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { private: // SubRule: Add tensorization-related transformations inline std::vector TransformForTensorization(TensorCoreState state) const; + // Subrule: Transform the layout of the output. This is necessary for an efficient cache write of the + // output to shared memory. + std::vector TransformIntermediateOutputLayout(TensorCoreState state); // Subrule: Add tensorized load inline std::vector AddReadReuseTensorCore(TensorCoreState state) const; // Subrule: Add tensorized store @@ -225,6 +229,9 @@ std::vector MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector(state)); }); states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = SubRule(std::move(states), [&](State state) { + return TransformIntermediateOutputLayout(Downcast(state)); + }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuseTensorCore(Downcast(state)); @@ -248,25 +255,162 @@ void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch, (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name); } +std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLayout( + TensorCoreState state) { + // Transform the intermediate output to packed layout + // [..., warp_m, warp_n, accum_frag_m, accum_frag_n, accum_elem_m, accum_elem_n] + // where warp_m, warp_n are thread indices bound to the warp id, accum_frag_m, accum_frag_n are + // the index of the fragments in each warp, accum_elem_m, accum_elem_n are the index of the + // elements in each accumulator fragment.
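As a concrete illustration of the packed layout described in the comment above, the following standalone Python sketch (not part of the patch) builds the same mapping through the public `tir.IndexMap` API. The fragment shape and the per-warp fragment counts are assumptions for the example; the counts (2, 1) happen to match the first matmul_relu test case below.

```python
# Illustrative sketch only: the packed accumulator layout described above.
from tvm import tir

frag_m, frag_n = 16, 16                  # wmma accumulator fragment shape (assumed)
warp_num_frag_m, warp_num_frag_n = 2, 1  # fragments per warp (assumed)

index_map = tir.IndexMap.from_func(
    lambda m, n: [
        (m // frag_m) // warp_num_frag_m,  # warp_m
        (n // frag_n) // warp_num_frag_n,  # warp_n
        (m // frag_m) % warp_num_frag_m,   # accum_frag_m
        (n // frag_n) % warp_num_frag_n,   # accum_frag_n
        m % frag_m,                        # accum_elem_m
        n % frag_n,                        # accum_elem_n
    ]
)
# sch.transform_layout(block, ("write", 0), index_map) would apply the same
# packing that TransformIntermediateOutputLayout performs on the accumulator.
```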
+ + // Get the shape of the wmma accumulator + auto [frag_shape_m, frag_shape_n] = [&]() { + tir::Block intrin_block = + Downcast( + tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body) + ->block; + tir::For loop_m = Downcast(intrin_block->body); + tir::For loop_n = Downcast(loop_m->body); + return std::make_tuple(loop_m->extent, loop_n->extent); + }(); + + // Get the tile index of the warp id (i.e. threadIdx.y) + auto it = std::find(tile_binds.begin(), tile_binds.end(), "threadIdx.y"); + ICHECK(it != tile_binds.end()); + auto tile_index_warp_id = std::distance(tile_binds.begin(), it); + + // Get the extent of the loop indicated by `loop_idx` inside the warp scope. + // For example, after spatial loops i, j are tiled, we will have + // tile_factors = ((i0, j0), (i1, j1), ..., (in, jn)) + // This function computes the product of tile_factors[i][loop_idx] for i > tile_index_warp_id. + // `loop_idx` can be negative, in which case it is counted from the end. + auto f_get_inner_tile_product = [&](int loop_idx) { + Array factors; + for (int i = tile_index_warp_id + 1; i < static_cast(s_indices_.size()); ++i) { + auto s_factors = state->tile_factors[s_indices_[i]]; + if (loop_idx < 0) { + loop_idx += s_factors.size(); + } + factors.push_back(s_factors[loop_idx]); + } + ICHECK(!factors.empty()); + if (factors.size() == 1) { + return factors[0]; + } + auto result = factors[0]; + for (int i = 1; i < static_cast(factors.size()); ++i) { + result = result * factors[i]; + } + return result; + }; + + // Compute the number of output fragments of each warp + auto warp_num_frag_m = f_get_inner_tile_product(-2); + auto warp_num_frag_n = f_get_inner_tile_product(-1); + + Schedule& sch = state->sch; + int buffer_ndim = static_cast(sch->Get(state->block_rv)->writes[0]->buffer->shape.size()); + // The dimension of the buffer should be no smaller than that of the tensor intrin.
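For intuition, here is a plain-Python sketch (not part of the patch) of what `f_get_inner_tile_product` above computes; `tile_factors` and `tile_index_warp_id` are stand-ins for the state members used in the C++ code.

```python
# Sketch only: product of the spatial tile factors below the warp-id tile level
# for the loop selected by loop_idx (negative indices count from the end).
def inner_tile_product(tile_factors, tile_index_warp_id, loop_idx):
    # tile_factors is assumed to look like [(i0, j0), (i1, j1), ..., (in, jn)]
    product = 1
    for s_factors in tile_factors[tile_index_warp_id + 1:]:
        product *= s_factors[loop_idx]
    return product

# warp_num_frag_m = inner_tile_product(factors, warp_tile_idx, -2)
# warp_num_frag_n = inner_tile_product(factors, warp_tile_idx, -1)
```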
+ ICHECK_GE(buffer_ndim, 2); + int num_higher_dims = buffer_ndim - 2; + + auto index_map = + tir::IndexMap::FromFunc(buffer_ndim, + // frag_shape_m and frag_shape_n are structural bindings that cannot + // be automatically captured until C++20 + [&, frag_shape_m = frag_shape_m, + frag_shape_n = frag_shape_n](const Array& indices) { + Array result; + result.reserve(indices.size() + 4); + for (int i = 0; i < num_higher_dims; ++i) { + result.push_back(indices[i]); + } + const auto& m = indices[num_higher_dims]; + const auto& n = indices[num_higher_dims + 1]; + auto accum_m = floormod(m, frag_shape_m); + auto accum_n = floormod(n, frag_shape_n); + auto outer_m = floordiv(m, frag_shape_m); + auto outer_n = floordiv(n, frag_shape_n); + + result.push_back(floordiv(outer_m, warp_num_frag_m)); + result.push_back(floordiv(outer_n, warp_num_frag_n)); + result.push_back(floormod(outer_m, warp_num_frag_m)); + result.push_back(floormod(outer_n, warp_num_frag_n)); + result.push_back(accum_m); + result.push_back(accum_n); + return result; + }); + sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map, + /*pad_value=*/NullOpt, /*assume_injective_transform=*/true); + + return {state}; +} + std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( TensorCoreState state) const { // Add the cache write stage for Tensor Core - int level = r_indices_.front() - 1; - const LoopRV& loop = state->tiles[level].back(); Schedule& sch = state->sch; auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator"); - sch->ReverseComputeAt(cache_write, loop, true); - - if (state->write_reuse.count(0)) { - // Fuse the iterators of the cache_write - Array buffer_loops = sch->GetLoops(state->write_reuse[0]); - ICHECK_GT(buffer_loops.size(), 2); - sch->Fuse(Array{buffer_loops.end() - 2, // The src shmem is always 2D - buffer_loops.end()}); - AnnotateCooperativeFetching(&sch, state->write_reuse[0]); + + // The compute block has been tiled by the warp shape and the fragment shape. + // We need to bind the cache write block (from the accumulator to the shared memory) to the warp + // id. The schedule is as follows: + // + // After adding cache write for wmma.accumulator, we will have + // for i0, j0, i1, j1, accum_m, accum_n: + // shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n] + // for i0', j0', i1', j1', accum_m', accum_n': + // global_mem[i0', j0', i1', j1', accum_m', accum_n'] = + // shared_mem[i0', j0', i1', j1', accum_m', accum_n'] + // where i0' and j0' are already bound to the block id and warp id. + // + // To reduce the shared memory usage and allow efficient data movement, we will apply + // transformations to generate the following schedule: + // + // for i1': + // for i0_j0 (fused and bound to threadIdx.y): + // for j1, accum_m, accum_n: + // shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n] + // for i0', j0', j1', accum_m', accum_n': + // global_mem[i0', j0', i1', j1', accum_m', accum_n'] = + // shared_mem[i0', j0', i1', j1', accum_m', accum_n'] + // + // i1' is reordered to the outermost. This effectively allows only a row (i.e. loop i1') of the + // fragments to be moved to the shared memory and then to the global memory each time. + // As a result, the shared memory for the output will only have a shape of [j1, accum_m, accum_n] + // instead of [i0 * i1 * accum_m, j0 * j1 * accum_n]. + + // Get the loops other than the innermost two loops (accum_m and accum_n).
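The loop surgery described in the comment above roughly corresponds to the following sequence of public `tir.Schedule` primitives. This is an illustrative, non-runnable sketch: the handles `sch`, `cache_write`, and `shared_to_global` are hypothetical stand-ins for the values held by the C++ implementation below, not part of the patch.

```python
# Sketch only (hypothetical handles): reschedule the epilogue so that only one
# row of fragments is staged in shared memory at a time.
i0, j0, i1, j1 = sch.get_loops(shared_to_global)[-6:-2]
sch.reorder(i1, i0, j0, j1)          # hoist i1' to the outermost position
sch.compute_at(cache_write, i1)      # copy one fragment row at a time
i0, j0 = sch.get_loops(cache_write)[-6:-4]
sch.bind(sch.fuse(i0, j0), "threadIdx.y")  # one warp per fused (i0, j0) index
```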
+ auto f_get_loops = [&](const BlockRV& block_rv) -> std::array { + Array buffer_loops = sch->GetLoops(block_rv); + ICHECK_GT(buffer_loops.size(), 6); + return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5], + buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]}; + }; + { + const auto& [i0, j0, i1, j1] = f_get_loops(state->write_reuse[0]); + sch->Reorder({i1, i0, j0, j1}); + sch->ComputeAt(cache_write, i1, true); + } + { + auto loops = f_get_loops(cache_write); + const auto& i0 = loops[0]; + const auto& j0 = loops[1]; + auto fused = sch->Fuse({i0, j0}); + sch->Bind(fused, "threadIdx.y"); } + sch->ReverseComputeInline(state->tensor_core_reindex_store); - TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin); + auto loops = sch->GetLoops(cache_write); + auto blockized_store = sch->Blockize(loops[loops.size() - 2]); + sch->Annotate(blockized_store, tir::attr::meta_schedule_auto_tensorize, + state->intrin_group.store_intrin); + + Array buffer_loops = sch->GetLoops(state->write_reuse[0]); + ICHECK_GT(buffer_loops.size(), 5); + sch->Fuse(Array{buffer_loops.end() - 5, // The src shmem is always 2D + buffer_loops.end()}); + AnnotateCooperativeFetching(&sch, state->write_reuse[0]); return {state}; } @@ -508,7 +652,8 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( state->sch->state(), GetRef(block), buffer_index, index_type); auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); buffer_sub_index_map.Set(lhs_buffer, sub_index_map); - state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt); + state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, + /*pad_value=*/NullOpt, /*assume_injective_transform=*/true); }; for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) { @@ -569,6 +714,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( auto node = MultiLevelTilingInitCommon( structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); + CHECK(node->reuse_write_.req == ReuseType::kMustReuse && + runtime::StorageScope::Create(node->reuse_write_.scope).rank == + runtime::StorageRank::kShared) + << "ValueError: Shared memory write reuse must be enabled for MultiLevelTilingTensorCore."; + node->intrin_groups.reserve(intrin_groups.size()); for (const auto& intrin_group_config : intrin_groups) { node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config)); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index d4c4a10fdd72..e68b64ea2d3a 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -48,11 +48,12 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode { return ScheduleRule(n); } - Array SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const; + std::pair, Array> SplitLoop(const Schedule& sch, BlockRV block, + LoopRV loop, int n_tiles) const; }; -Array MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv, - LoopRV loop_rv, int n_tiles) const { +std::pair, Array> MultiLevelTilingWideVectorNode::SplitLoop( + const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, int n_tiles) const { const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); const tir::StmtSRef 
block_sref = sch->GetSRef(block_rv); const tir::BlockNode* block_node = block_sref->StmtAs(); @@ -99,12 +100,14 @@ Array MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch Array outer_splits = sch->Split( /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()}); outer_splits.push_back(inner_splits[1]); - return outer_splits; + outer_factors.push_back(PrimExpr(vec_len)); + return {outer_factors, outer_splits}; } else { Array factors(n_tiles - 1, PrimExpr(1)); factors.push_back(loop->extent); - return sch->Split(/*loop=*/loop_rv, - /*factors=*/{factors.begin(), factors.end()}); + Array splits = sch->Split(/*loop=*/loop_rv, + /*factors=*/{factors.begin(), factors.end()}); + return {factors, splits}; } } } diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index e9bff1b6fdee..ab328efaa6d1 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -76,8 +76,6 @@ class BlockReadWriteDetector : public StmtExprVisitor { Map buffer_var_map_; /*! \brief The target buffer var mapping to its matching */ std::unordered_map match_buffers_; - /*! \brief The analyzer for simplifying*/ - arith::Analyzer analyzer_; /*! * \brief Update read/write buffers and regions with provided buffer and region @@ -330,7 +328,12 @@ Array BlockReadWriteDetector::CollectRegions( ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { const tvm::arith::IntSet& range = regions[i][j]; - region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); + if (range.IsSinglePoint()) { + PrimExpr min = range.min(); + region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1))); + } else { + region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); + } } res.push_back(BufferRegion(buffers[i], region)); } diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 9d89c641630b..5353a051a60a 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -43,7 +43,7 @@ class TensorIntrinMismatchError : public ScheduleError { std::ostringstream os; os << "The stmt {0} doesn't match the tensor intrin\nThe pattern attempting to be matched:\n" << lhs_stmt_ << "\nDoes not match the tensorize description:\n" - << rhs_stmt_; + << rhs_stmt_ << '\n'; for (const auto& msg : error_messages_) { os << msg << std::endl; } @@ -173,6 +173,9 @@ bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& oth bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { const auto* rhs = other.as(); + for (const IterVar& iter : op->iter_vars) { + lhs_analyzer_.Bind(iter->var, iter->dom); + } // Check block equality. // All iter vars and buffer regions including the order should match. // When checking iter vars, DefEqual is used to remap variables. 
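For intuition, here is a standalone example (not part of the patch) of why binding the LHS iterator domains matters: once a domain is bound, the analyzer can prove the floordiv/floormod simplifications that the new `lhs_analyzer_` relies on when comparing buffer regions.

```python
# Standalone illustration: binding an iterator's domain lets the analyzer
# simplify expressions over that iterator.
import tvm
from tvm import tir, arith

ana = arith.Analyzer()
i = tir.Var("i", "int32")
ana.bind(i, tvm.ir.Range.from_min_extent(0, 16))
print(ana.simplify(tir.floordiv(i, 16)))  # 0, since 0 <= i < 16
print(ana.simplify(tir.floormod(i, 16)))  # i
```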
@@ -465,7 +468,7 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf } return false; } - if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) { + if (!lhs_analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) { if (assert_mode_) { std::ostringstream os; os << "Buffer base index consistency check failed due to unequal index base: " @@ -487,7 +490,8 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf } return false; } - PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - indices_base[i + offset]); + PrimExpr normalized_lhs_min = + lhs_analyzer_.Simplify((lhs->region[i + offset]->min - indices_base[i + offset])); if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) { if (assert_mode_) { std::ostringstream os; diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index 394d82867393..debf0f946e28 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -102,8 +102,13 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { bool assert_mode_; /*! \brief Whether it is visiting the scope block (the outermost block). */ bool is_scope_block = true; - /*! \brief The arithmetic analyzer. */ + /*! \brief The arithmetic analyzer for comparing LHS and RHS */ arith::Analyzer analyzer_; + /*! + * \brief The arithmetic analyzer for simplifying expressions on LHS. + * This analyzer only contains the domains of the iterators on LHS. + */ + arith::Analyzer lhs_analyzer_; /*! \brief Additional error messages. Only used when assert_mode is true. */ std::vector error_messages_; // variable remap if any diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index 9b869b4436c0..1cab2554e88f 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -83,39 +83,39 @@ def test_matmul_relu(shared_scope): @T.prim_func def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope=shared_scope) - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope=shared_scope) + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for 
ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax2_0_0 in T.serial(1): - for ax0_ax1_fused in T.serial(4096): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4096): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(4096): + for ax0_ax1_fused in range(4096): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(4): + for ax2_0_1 in range(4): for ax0_0, ax1_0 in T.grid(2, 2): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): @@ -127,8 +127,8 @@ def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): @@ -141,44 +141,54 @@ def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3) v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(1024): - with T.block("C_reindex_shared"): - v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) - T.reads(C_reindex_shared[v0, v1]) - T.writes(compute[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) - compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0)) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, 
v1_o, v0_o % 2, 0, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(1, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) # fmt: on decision_0 = [ @@ -223,44 +233,42 @@ def test_matmul_relu_with_fallback(): # fmt: off @T.prim_func def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: - # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + # with T.block("root"): + C_reindex_shared = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", 
scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax2_0_0 in T.serial(2): - for ax0_ax1_fused in T.serial(2048): + for ax2_0_0 in range(2): + for ax0_ax1_fused in range(2048): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 64) v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":4}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4}) A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(8192): + for ax0_ax1_fused in range(8192): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(1): + for ax2_0_1 in range(1): for ax0_0, ax1_0 in T.grid(2, 4): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -271,9 +279,9 @@ def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -285,44 +293,54 @@ def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0_3 * 4 + ax1_0_4) 
v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 4 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 4): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(4096): - with T.block("C_reindex_shared"): - v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 128) - v1 = T.axis.spatial(128, ax0_ax1_fused % 128) - T.reads(C_reindex_shared[v0, v1]) - T.writes(compute[v0, v1]) - 
T.block_attr({"meta_schedule.cooperative_fetch":4}) - compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0)) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 4): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused) + v1 = T.axis.spatial(2, ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(4, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(2048): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused) + v1 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) # fmt: on decision_0 = [ ("SamplePerfectTile", [2, 2, 1, 1, 2]), @@ -373,46 +391,46 @@ def test_conv2d(shared_scope): @T.prim_func def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, 3, 32, 32), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 32), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16") - conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], dtype="float32", scope=shared_scope) - conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 32], dtype="float32", scope="wmma.accumulator") - PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", scope=shared_scope) - weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", scope=shared_scope) - PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], dtype="float16", scope="wmma.matrix_a") - 
weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], dtype="float16", scope="wmma.matrix_b") + PadInput = T.alloc_buffer((1, 18, 18, 32), "float16") + conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope=shared_scope) + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope="wmma.accumulator") + PadInput_reindex_shared = T.alloc_buffer((256, 288), "float16", scope=shared_scope) + weight_reindex_shared = T.alloc_buffer((288, 32), "float16", scope=shared_scope) + PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 288), "float16", scope="wmma.matrix_a") + weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((288, 32), "float16", scope="wmma.matrix_b") for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) - T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16") + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float16(0)) for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): - for ax2_0_0 in T.serial(1): - for ax0_ax1_fused in T.serial(4608): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4608): with T.block("PadInput_reindex_shared"): v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 288) v1 = T.axis.spatial(288, ax0_ax1_fused % 288) T.reads(PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]) T.writes(PadInput_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) PadInput_reindex_shared[v0, v1] = PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32] - for ax0_ax1_fused in T.serial(4608): + for ax0_ax1_fused in range(4608): with T.block("weight_reindex_shared"): v0 = T.axis.spatial(288, ax0_ax1_fused // 16) v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16) T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]) T.writes(weight_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) weight_reindex_shared[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1] - for ax2_0_1 in T.serial(18): + for ax2_0_1 in range(18): for ax0_0, ax1_0 in T.grid(1, 1): with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0) v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0) - T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + 
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("PadInput_reindex_shared_wmma.matrix_a"): @@ -424,8 +442,8 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, with T.block("weight_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(18, ax2_0_1 + ax0_0) v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) - T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(weight_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("weight_reindex_shared_wmma.matrix_b"): @@ -438,44 +456,49 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, v0_o = T.axis.spatial(16, ax0_0_4 + ax0_0_1_ax1_0_1_fused + ax0_0_3) v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0_3 + ax1_0_4) v2_o = T.axis.reduce(18, ax2_0_0 * 18 + ax2_0_1 + ax2_0_2) - T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("conv2d_nhwc_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("conv2d_nhwc"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + 
v1_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(1, 1): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0) - v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(256): - with T.block("conv2d_nhwc_reindex_shared"): - v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 16) - v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16) - T.reads(conv2d_nhwc_reindex_shared[v0, v1]) - T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) - conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(1): + for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2_1, ax3]) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(256): + with T.block("conv2d_nhwc_reindex_shared"): + v0, v1, v2 = T.axis.remap("SSS", [ax0_0_1_ax1_0_1_fused, 
ax0_0_0_ax1_0_0_fused, ax2]) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 16, 1, 1, 1]), @@ -551,40 +574,40 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body # with T.block("root") - C = T.alloc_buffer([128, 128], dtype="float32") - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope=shared_scope) - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + C = T.alloc_buffer((128, 128)) + C_reindex_shared = T.alloc_buffer((4, 4, 2, 2, 16, 16), scope=shared_scope) + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 4, 2, 2, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): - for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order":[0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage":[0, 0, 0, 0, 0, 1, 1]}): - for ax0_ax1_fused in T.serial(1024): + for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): + for ax0_ax1_fused in range(1024): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused % 32) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":4, "tir.manifest_shared_memory_local_stage":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 4, "tir.manifest_shared_memory_local_stage": 1}) A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(1024): + for ax0_ax1_fused in range(1024): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "double_buffer_scope":0, 
"meta_schedule.cooperative_fetch":2, "tir.manifest_shared_memory_local_stage":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 2, "tir.manifest_shared_memory_local_stage": 1}) B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(2, annotations={"software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}): + for ax2_0_1 in T.serial(2, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0, ax1_0 in T.grid(2, 1): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): @@ -596,8 +619,8 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): @@ -610,50 +633,61 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0_3 * 2 + ax1_0_4) v2_o = T.axis.reduce(8, ax2_0_0 * 2 + ax2_0_1 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o 
* 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.grid(1024): - with T.block("C_reindex_shared"): - v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32) - T.reads(C_reindex_shared[v0, v1]) - T.writes(C[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) - C[v0, v1] = C_reindex_shared[v0, v1] + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 2): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4) + v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(2, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 
0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4) + v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32] = C_reindex_shared[v0, v1, v2, v3, v4, v5] for i0, i1 in T.grid(128, 128): with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(C[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(C[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.max(C[v_i0, v_i1], T.float32(0)) + # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 4, 1, 1, 2]), @@ -693,141 +727,6 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, ) -def test_matmul_relu_global(): - # fmt: off - @T.prim_func - def matmul_relu_global_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - C = T.alloc_buffer([128, 128], dtype="float32") - C_reindex_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") - for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): - for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"): - for ax0_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): - for ax2_0_0 in T.serial(2): - for ax0_ax1_fused in T.serial(8192): - with T.block("A_reindex_shared"): - v0 = T.axis.spatial(128, ax0_ax1_fused // 64) - v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64) - T.reads(A[v0, v1]) - T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) - A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(8192): - with T.block("B_reindex_shared"): - v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128) - v1 = T.axis.spatial(128, ax0_ax1_fused % 128) - T.reads(B[v0, v1]) - T.writes(B_reindex_shared[v0, v1]) - 
T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) - B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(2): - for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("A_reindex_shared_wmma.matrix_a_o"): - v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0) - v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("A_reindex_shared_wmma.matrix_a"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_0, ax1_0 in T.grid(2, 4): - with T.block("B_reindex_shared_wmma.matrix_b_o"): - v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("B_reindex_shared_wmma.matrix_b"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 4, 2, 1, 1): - with T.block("C_o"): - v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0_3 + ax0_0_4) - v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0_3) - v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) - with T.init(): - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_init"): - v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads() - T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) - for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("C"): - v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + 
v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(1, 4): - with T.block("C_reindex_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0) - T.reads(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_global"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for i0, i1 in T.grid(128, 128): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(C[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) - # fmt: on - decision_0 = [ - ("SamplePerfectTile", [1, 1, 8, 1, 1]), - ("SamplePerfectTile", [1, 1, 2, 4, 1]), - ("SamplePerfectTile", [2, 2, 2]), - ("SampleCategorical", 0), - ("SampleCategorical", 0), - ] - mod = te.create_prim_func( - te_workload.matmul_relu( - n=128, - m=128, - k=128, - in_dtype="float16", - out_dtype="float32", - ) - ) - actual = generate_design_space( - kind="cuda", - mod=mod, - target=tvm.target.Target("cuda"), - types=None, - sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")] - + get_rules("cuda", ms.schedule_rule.AutoInline), - ) - check_sketches( - mod, - sketches=actual, - expected_mods=[matmul_relu_global_0], - expected_decisions=[decision_0], - ) - - def test_matmul_relu_non_tensorizable(): # expected to do nothing on non-tensorizable workloads mod = te.create_prim_func( @@ -842,7 +741,7 @@ def test_matmul_relu_non_tensorizable(): mod=mod, target=tvm.target.Target("cuda"), types=None, - sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")] + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + get_rules("cuda", ms.schedule_rule.AutoInline), ) tvm.ir.assert_structural_equal(mod, sch.mod["main"]) @@ -856,40 +755,40 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body # with T.block("root") - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((128, 128), "float16", 
scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax2_0_0 in T.serial(1): - for ax0_ax1_fused in T.serial(4096): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4096): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) - A_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, A[v0, v1], T.float16(0), dtype="float16") - for ax0_ax1_fused in T.serial(4096): + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + A_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, A[v0, v1], T.float16(0)) + for ax0_ax1_fused in range(4096): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) - B_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, B[v0, v1], T.float16(0), dtype="float16") - for ax2_0_1 in T.serial(4): + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + B_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, B[v0, v1], T.float16(0)) + for ax2_0_1 in range(4): for ax0_0, ax1_0 in T.grid(2, 2): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -900,9 +799,9 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + 
T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -914,45 +813,56 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3) v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + 
v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(1024): - with T.block("C_reindex_shared"): - T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32 < 127) - v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) - T.reads(C_reindex_shared[v0, v1]) - T.writes(compute[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) - compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0)) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(1, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2 + 0) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 16 < 127) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) + 
# fmt: on decision_0 = [ @@ -994,25 +904,25 @@ def test_conv_1x1(): @T.prim_func def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 64], dtype="float32", scope="shared") - conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 64], dtype="float32", scope="wmma.accumulator") - PadInput_reindex_shared = T.alloc_buffer([256, 64], dtype="float16", scope="shared") - weight_reindex_shared = T.alloc_buffer([1, 1, 64, 64], dtype="float16", scope="shared") - PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 64], dtype="float16", scope="wmma.matrix_a") - weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 64, 64], dtype="float16", scope="wmma.matrix_b") + conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 4, 1, 1, 16, 16), scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 4, 1, 1, 16, 16), scope="wmma.accumulator") + PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", scope="shared") + weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16", scope="shared") + PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64), "float16", scope="wmma.matrix_a") + weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64), "float16", scope="wmma.matrix_b") for ax2_0_0_ax3_0_0_fused in T.thread_binding(16, thread="blockIdx.y"): for ax2_0_1_ax3_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 1): - for ax0_ax1_fused in T.serial(1024): + for ax0_ax1_fused in range(1024): with T.block("PadInput_reindex_shared"): v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64) v1 = T.axis.spatial(64, ax0_ax1_fused % 64) T.reads(inputs[v0 // 256, v0 // 16, v0 % 16, v1]) T.writes(PadInput_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) PadInput_reindex_shared[v0, v1] = inputs[v0 // 256, v0 // 16, v0 % 16, v1] - for ax0_ax1_ax2_ax3_fused in T.serial(2048): + for ax0_ax1_ax2_ax3_fused in range(2048): with T.block("weight_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) @@ -1020,16 +930,16 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( v3 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align":[[0, 2, 32, 8]], "meta_schedule.cooperative_fetch":4}) + T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 4}) weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): for ax0_0_1, ax1_0_1 in T.grid(1, 4): with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0_1) v1_o = T.axis.spatial(4, ax1_0_1) - T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - 
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) + T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1_1, ax1_1_1 in T.grid(16, 16): with T.block("PadInput_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) @@ -1040,9 +950,9 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( with T.block("weight_reindex_shared_wmma.matrix_b_o"): v0, v1, v2_o = T.axis.remap("SSS", [ax0, ax1, ax2_0]) v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0) - T.reads(weight_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) + T.reads(weight_reindex_shared[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): with T.block("weight_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) @@ -1056,44 +966,53 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( v2_o = T.axis.spatial(16, ax2_0_4 + ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3) v3_o = T.axis.spatial(4, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3) v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2) - T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 : v4_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax2_1, ax3_1 in T.grid(16, 16): with T.block("conv2d_nhwc_init"): v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = T.float32(0) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i_init, v3_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i_init, v3_i_init] = T.float32(0) for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): with T.block("conv2d_nhwc"): v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) - 
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i], "float32") - for ax0_0, ax1_0 in T.grid(1, 1): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0) - v1_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax1_0) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(512): - with T.block("conv2d_nhwc_reindex_shared"): - v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_fused % 32) - T.reads(conv2d_nhwc_reindex_shared[v0, v1]) - T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) - conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) + for ax2 in range(1): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) + v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax0_ax1_fused) + v2, v3 = T.axis.remap("SS", [ax2_1, ax3]) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, 
v1, v2, v3, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) + v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256) + v2 = T.axis.spatial(1, ax2) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 2}) + conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] # fmt: on decision_0 = [ diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 1e5fd8843ba3..b9f35ed553e1 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -1124,7 +1124,7 @@ def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") with T.block(): T.reads(A[tx, 0]) - T.writes(B[0, tx, 0]) + T.writes(B[T.FloorMod(0, 2), tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[T.FloorMod(0, 2), tx, 0] = A[tx, 0] * T.float32(2) @@ -1350,8 +1350,8 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N B[i % 2, tx, 0] = A[tx, i] * T.float32(2) with T.block(): T.where(i == 1 and i - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) + T.reads(B[(i - 1) % 2, tx, 0]) + T.writes(C[(i - 1) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 1): @@ -1366,14 +1366,14 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N with T.block(): T.where(i + 2 < 16) T.reads(A[tx, i + 2]) - T.writes(B[i % 2, tx, 0]) + T.writes(B[(i + 2) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[(i + 2) % 2, tx, 0] = A[tx, i + 2] * T.float32(2) with T.block(): T.where(i + 2 - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) + T.reads(B[(i - 1 + 2) % 2, tx, 0]) + T.writes(C[(i - 1 + 2) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 1): @@ -1394,8 +1394,8 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N for i in T.unroll(2): with T.block(): T.where(i + 16 - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) 
+ T.reads(B[(i - 1 + 16) % 2, tx, 0]) + T.writes(C[(i - 1 + 16) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 0 - i):
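Note: the updated expectations above route the wmma accumulator through a packed 6-D intermediate buffer (e.g. shape (4, 4, 2, 2, 16, 16) in matmul_relu_pipeline_0) instead of the flat (128, 128) buffer the old expectations allocated. As a quick standalone sanity check of the index arithmetic in those expectations — not part of this patch, and with pack/unpack being hypothetical helper names — the sketch below confirms that the final write-back C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32] = C_reindex_shared[v0, v1, v2, v3, v4, v5] visits every element of the 128x128 output exactly once.

# Standalone sketch (plain Python, no TVM dependency assumed). It checks that the
# 6-D packed layout used in the updated matmul_relu_pipeline_0 expectation is a
# bijection over the original 128x128 output coordinates.

def pack(i, j):
    # Forward mapping implied by the (4, 4, 2, 2, 16, 16) buffer shape:
    # outer 32x32 tile row/col, 16x16 fragment row/col within the tile, element row/col.
    return (i // 32, j // 32, i % 32 // 16, j % 32 // 16, i % 16, j % 16)

def unpack(v0, v1, v2, v3, v4, v5):
    # Matches the write-back in the expected schedule:
    # C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32] = C_reindex_shared[v0, v1, v2, v3, v4, v5]
    return (v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32)

packed = {pack(i, j) for i in range(128) for j in range(128)}
assert len(packed) == 128 * 128  # every packed coordinate is distinct
assert all(unpack(*pack(i, j)) == (i, j) for i in range(128) for j in range(128))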