From 88c2e2d17972044d39131354a0d3f502e1464267 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Mon, 19 Sep 2022 16:48:09 -0700
Subject: [PATCH 1/2] [TIR, MetaSchedule] Preserve unit block iters for
 auto-tensorization

* Update schedule primitives (ReIndex, TransformBlockLayout) to preserve unit
  iters. Added test cases.
* Allow workloads with unit dimensions to be detected during auto-tensorization
  pattern matching. This allows padding to be added for tensorizing such
  workloads.
---
 src/tir/ir/stmt.cc                            |   2 +-
 src/tir/schedule/analysis.h                   |  12 +
 src/tir/schedule/analysis/analysis.cc         |   9 +
 src/tir/schedule/ir_comparator.cc             |   4 +-
 .../schedule/primitive/cache_read_write.cc    |  23 +-
 .../primitive/layout_transformation.cc        |  42 ++-
 src/tir/schedule/transform.cc                 |  25 +-
 src/tir/schedule/transform.h                  |   9 +-
 ...test_meta_schedule_schedule_rule_mlt_tc.py | 356 +++++++++---------
 .../unittest/test_tir_schedule_analysis.py    |   5 +-
 .../unittest/test_tir_schedule_reindex.py     |  90 ++++-
 .../test_tir_schedule_transform_layout.py     |  73 +++-
 12 files changed, 407 insertions(+), 243 deletions(-)

diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index e21d014fe185..8f2a7b4ffe5b 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -831,7 +831,7 @@ BufferRegion BufferRegion::FromPoint(Buffer buffer, Array<PrimExpr> indices) {
       region.push_back(
           Range::FromMinExtent(ramp_index->base, ramp_index->stride * ramp_index->lanes));
     } else {
-      region.push_back(Range::FromMinExtent(index, 1));
+      region.push_back(Range::FromMinExtent(index, make_const(index.dtype(), 1)));
     }
   }
   return BufferRegion(buffer, region);
 }
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 57165fd08ad4..7df991826728 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -724,6 +724,18 @@ Optional<Array<PrimExpr>> CheckTrivialBufferIndices(const T& buffer_access) {
   return indices;
 }
 
+/*!
+ * \brief Simplify non-trivial expressions
+ * \param expr The expression to be simplified
+ * \param analyzer The analyzer
+ * \return The simplified expression
+ *
+ * During scheduling, we often need to preserve block iters in trivial expressions that can be
+ * simplified to constant values, because simplifying away the block iters may result in loss of
+ * information needed for further scheduling and analysis.
+ */
+PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer);
+
 /*! \brief Necessary information used for tensorization */
 class TensorizeInfoNode : public Object {
  public:
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 294826a1f6b9..384d006562f0 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1646,6 +1646,15 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
   }
 }
 
+PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer) {
+  auto simplified = analyzer->Simplify(expr);
+  if (simplified->IsInstance<IntImmNode>()) {
+    return expr;
+  } else {
+    return simplified;
+  }
+}
+
 TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
 
 /*!
\brief Auxiliary data structure of information extracted from tensor intrin description */ diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index bfd394f24de7..648305d3655d 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -444,8 +444,8 @@ bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { return false; } std::vector lhs_indices; - for (const auto& index : lhs->indices) { - lhs_indices.push_back(analyzer_.Simplify(index)); + for (const PrimExpr& index : lhs->indices) { + lhs_indices.push_back(SimplifyNonTrivialExpr(index, &analyzer_)); } auto is_scalar_access = [](const Array& indices, PrimExpr index) { diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 489308ae8c0f..e03b1058d4ef 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -188,8 +188,6 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, Array new_block_iters; // the substition map from the original block iter to the iters of the reindex block std::unordered_map block_var_replace_map; - // block access region of reindexed buffer and target buffer - Region reindex_region, target_region; // indices to access the reindex buffer and the target buffer Array reindex_indices, target_indices; @@ -201,7 +199,7 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, Var var("v" + std::to_string(new_block_iters.size()), iter->var->dtype); bool used = covered.count(iter->var); if (used) { - new_block_iters.push_back(IterVar(/*dom=*/used ? iter->dom : Range::FromMinExtent(0, 1), + new_block_iters.push_back(IterVar(/*dom=*/iter->dom, /*var=*/var, /*IterVarType=*/kDataPar)); } else { @@ -209,16 +207,11 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, } if (used) { reindex_indices.push_back(var); - reindex_region.push_back(Range::FromMinExtent(var, IntImm(var->dtype, 1))); } block_var_replace_map[iter->var] = var; } // Step 2: Replace the original block iters with the new block iters - BufferRegion buffer_region = buffer_index_type == BufferIndexType::kWrite - ? 
block->writes[buffer_index] - : block->reads[buffer_index]; - target_region = Substitute(buffer_region->region, block_var_replace_map); for (const PrimExpr& index : original_indices) { target_indices.push_back(Substitute(index, block_var_replace_map)); } @@ -232,13 +225,9 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, Array dst_indices{nullptr}; if (buffer_index_type == BufferIndexType::kWrite) { - src_region = reindex_region; - dst_region = target_region; src_indices = reindex_indices; dst_indices = target_indices; } else { - src_region = target_region; - dst_region = reindex_region; src_indices = target_indices; dst_indices = reindex_indices; } @@ -246,11 +235,9 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, // Create the body block Block new_block( /*iter_vars=*/new_block_iters, - /*reads=*/ - {BufferRegion(info->read_buffer, src_region)}, - /*writes=*/ - {BufferRegion(info->write_buffer, dst_region)}, - /*name_hint=*/buffer_region->buffer->name + "_reindex", + /*reads=*/{BufferRegion::FromPoint(info->read_buffer, src_indices)}, + /*writes=*/{BufferRegion::FromPoint(info->write_buffer, dst_indices)}, + /*name_hint=*/info->write_buffer->name + "_reindex", /*body=*/ BufferStore(info->write_buffer, BufferLoad(info->read_buffer, src_indices), dst_indices)); @@ -1169,7 +1156,7 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_inde analyzer.Bind(iter->var, iter->dom); } original_indices.MutateByApply( - [&analyzer](const PrimExpr& expr) { return analyzer.Simplify(expr); }); + [&analyzer](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, &analyzer); }); // Collect block iters appearing in the original_indices std::unordered_set covered; diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 025723e1793d..d14ea317f847 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -699,7 +699,9 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { void RewriteBufferAccess(Buffer* buffer, Array* indices) { *buffer = new_buffer_; - *indices = index_map_->MapIndices(*indices, analyzer_); + *indices = index_map_->MapIndices(*indices); + (*indices).MutateByApply( + [&](const PrimExpr& e) { return SimplifyNonTrivialExpr(e, analyzer_); }); } using Parent = arith::IRMutatorWithAnalyzer; @@ -1113,7 +1115,7 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError { IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -1194,14 +1196,6 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, Array transformed_block_iters = index_map->MapIndices(block_vars); Array new_block_iter_range = index_map->MapShape(block_iter_range_array); - auto iter_map = arith::DetectIterMap( - /*indices=*/transformed_block_iters, /*input_iters=*/block_iter_dom, /*predicate=*/Bool(true), - /*check_level=*/arith::IterMapLevel::Bijective, &analyzer, - /*simplify_trivial_iterators=*/true); - if (iter_map->indices.empty()) { - throw NotBijectiveAffineIndexMapError(self->mod, index_map); - } - // Step 5: Create the new block after transformation. // Step 5.1: Create new block iters. 
After applying the IndexMap f to block iters ax_0, ..., ax_n, @@ -1221,18 +1215,28 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters // in the body. + Map inverse_subst_map; + // Construct the inverse map + { + Array initial_ranges; + for (const PrimExpr& extent : block_iter_range_array) { + initial_ranges.push_back(Range::FromMinExtent(make_const(extent.dtype(), 0), extent)); + } + IndexMap inverse_index_map{nullptr}; + try { + inverse_index_map = index_map.Inverse(initial_ranges); + } catch (...) { + throw NotBijectiveAffineIndexMapError(self->mod, index_map); + } - auto inverse_map = arith::InverseAffineIterMap(iter_map->indices, new_block_vars); - // Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant - // zero. - for (const auto& iter_var : block_ptr->iter_vars) { - if (inverse_map.find(iter_var->var) == inverse_map.end()) { - ICHECK(is_one(iter_var->dom->extent)); - inverse_map.Set(iter_var->var, 0); + Array inversed_new_block_vars = inverse_index_map->MapIndices( + new_block_vars); // old block vars written in terms of new block vars + + for (int i = 0, n = block_vars.size(); i < n; ++i) { + inverse_subst_map.Set(Downcast(block_vars[i]), inversed_new_block_vars[i]); } } - - Block new_block = Downcast(Substitute(GetRef(block_ptr), inverse_map)); + Block new_block = Downcast(Substitute(GetRef(block_ptr), inverse_subst_map)); new_block.CopyOnWrite()->iter_vars = new_block_iters; new_block = Downcast(BlockBufferAccessSimplifier::Simplify(new_block, &analyzer)); diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 7a720fe3eae2..e91c5d142c04 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -359,14 +359,25 @@ void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array* old_ auto fmutate = [this](const BufferRegion& buffer_region) { std::vector new_buffer_region; for (const auto& range : buffer_region->region) { - new_buffer_region.push_back(Range::FromMinExtent(analyzer_->Simplify(range->min), - analyzer_->Simplify(range->extent))); + if (is_one(range->extent) && range->min->IsInstance()) { + new_buffer_region.push_back(Range::FromMinExtent( + SimplifyNonTrivialExpr(range->min, analyzer_), make_const(range->min.dtype(), 1))); + } else { + new_buffer_region.push_back( + Range::FromMinExtent(SimplifyNonTrivialExpr(range->min, analyzer_), + SimplifyNonTrivialExpr(range->extent, analyzer_))); + } } return BufferRegion(buffer_region->buffer, new_buffer_region); }; (*old_access_regions).MutateByApply(fmutate); } +void BlockBufferAccessSimplifier::SimplifyBufferIndices(Array* indices) { + (*indices).MutateByApply( + [this](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, analyzer_); }); +} + Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) { Block block = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); auto* n = block.CopyOnWrite(); @@ -376,13 +387,15 @@ Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) { } Stmt BlockBufferAccessSimplifier::VisitStmt_(const BufferStoreNode* op) { - auto node = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); - return VisitBufferAccess(std::move(node)); + BufferStore node = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); + SimplifyBufferIndices(&node.CopyOnWrite()->indices); + return std::move(node); } PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) 
{ - auto node = Downcast(arith::IRMutatorWithAnalyzer::VisitExpr_(op)); - return VisitBufferAccess(std::move(node)); + BufferLoad node = Downcast(arith::IRMutatorWithAnalyzer::VisitExpr_(op)); + SimplifyBufferIndices(&node.CopyOnWrite()->indices); + return std::move(node); } } // namespace tir diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 2bba13e2bd1c..3593d6b9a444 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -226,16 +226,11 @@ class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitStmt_; void SimplifyAccessRegion(Array* old_access_regions); + void SimplifyBufferIndices(Array* indices); + Stmt VisitStmt_(const BlockNode* op) final; Stmt VisitStmt_(const BufferStoreNode* op) final; PrimExpr VisitExpr_(const BufferLoadNode* op) final; - - template - Node VisitBufferAccess(Node node) { - node.CopyOnWrite()->indices.MutateByApply( - [this](const PrimExpr& expr) { return analyzer_->Simplify(expr); }); - return node; - } }; } // namespace tir diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index f7a5ce997edf..a53c1062b98d 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -365,30 +365,30 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16") - for ax0_0_ax1_0_0_ax2_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): - for ax0_1_ax1_0_1_ax2_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): - for ax0_2_ax1_0_2_ax2_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): - for ax3_0_0 in T.serial(1): + for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_0_0 in T.serial(1): for ax0_ax1_fused in T.serial(4608): with T.block("PadInput_reindex_shared"): - v0 = T.axis.spatial(256, ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0_ax1_fused // 288) + v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 288) v1 = T.axis.spatial(288, ax0_ax1_fused % 288) - T.reads(PadInput[0, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]) + T.reads(PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]) T.writes(PadInput_reindex_shared[v0, v1]) T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2}) - PadInput_reindex_shared[v0, v1] = PadInput[0, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32] + PadInput_reindex_shared[v0, v1] = PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32] for ax0_ax1_fused in T.serial(4608): with T.block("weight_reindex_shared"): v0 = T.axis.spatial(288, ax0_ax1_fused // 16) - v1 = T.axis.spatial(32, ax0_0_ax1_0_0_ax2_0_0_fused * 16 + ax0_ax1_fused % 16) + v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16) T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]) T.writes(weight_reindex_shared[v0, v1]) T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) 
weight_reindex_shared[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1] - for ax3_0_1 in T.serial(18): + for ax2_0_1 in T.serial(18): for ax0_0, ax1_0 in T.grid(1, 1): with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): - v0_o, v1_o = T.axis.remap("SS", [ax0_1_ax1_0_1_ax2_0_1_fused, ax3_0_1]) + v0_o, v1_o = T.axis.remap("SS", [ax0_0_1_ax1_0_1_fused, ax2_0_1]) T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) @@ -400,7 +400,7 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(1, 1): with T.block("weight_reindex_shared_wmma.matrix_b_o"): - v0_o, v1_o = T.axis.remap("SS", [ax3_0_1, ax0_0_ax1_0_0_ax2_0_0_fused]) + v0_o, v1_o = T.axis.remap("SS", [ax2_0_1, ax0_0_0_ax1_0_0_fused]) T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) @@ -410,32 +410,31 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, T.reads(weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_3, ax1_0_3, ax2_0_3, ax3_0_2, ax0_4, ax1_0_4, ax2_0_4 in T.grid(1, 1, 1, 1, 1, 1, 1): + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 1, 1, 1): with T.block("conv2d_nhwc_o"): - v0 = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(16, ax1_0_4 + ax0_1_ax1_0_1_ax2_0_1_fused + ax1_0_3) - v2_o = T.axis.spatial(2, ax0_0_ax1_0_0_ax2_0_0_fused + ax2_0_3 + ax2_0_4) - v3_o = T.axis.reduce(18, ax3_0_0 * 18 + ax3_0_1 + ax3_0_2) - T.reads(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 : v1_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v3_o * 16 : v3_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16]) + v0_o = T.axis.spatial(16, ax0_0_4 + ax0_0_1_ax1_0_1_fused + ax0_0_3) + v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0_3 + ax1_0_4) + v2_o = T.axis.reduce(18, ax2_0_0 * 18 + ax2_0_1 + ax2_0_2) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) with T.init(): - for ax1_1, ax2_1 in T.grid(16, 16): + for ax0_1, ax1_1 in T.grid(16, 16): with T.block("conv2d_nhwc_init"): - v1_i_init, v2_i_init = T.axis.remap("SS", [ax1_1, ax2_1]) + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init]) - 
conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init] = T.float32(0) - for ax1_1, ax2_1, ax3_1 in T.grid(16, 16, 16): + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("conv2d_nhwc"): - v1_i, v2_i, v3_i = T.axis.remap("SSR", [ax1_1, ax2_1, ax3_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i], PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i]) + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i], "float32") + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") for ax0_0, ax1_0 in T.grid(1, 1): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o, v1_o = T.axis.remap("SS", [ax0_1_ax1_0_1_ax2_0_1_fused, ax0_0_ax1_0_0_ax2_0_0_fused]) + v0_o, v1_o = T.axis.remap("SS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused]) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) @@ -447,15 +446,14 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1 in T.grid(16, 16): with T.block("conv2d_nhwc_reindex_shared"): - v0 = T.axis.spatial(256, ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0) - v1 = T.axis.spatial(32, ax0_0_ax1_0_0_ax2_0_0_fused * 16 + ax1) + v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0) + v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax1) T.reads(conv2d_nhwc_reindex_shared[v0, v1]) - T.writes(conv2d_nhwc[0, v0 // 16, v0 % 16, v1]) + T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1]) T.block_attr({"meta_schedule.cooperative_fetch":3}) - conv2d_nhwc[0, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] + conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] # fmt: on decision_0 = [ - ("SamplePerfectTile", [1, 1, 1, 1, 
1]), ("SamplePerfectTile", [1, 16, 1, 1, 1]), ("SamplePerfectTile", [2, 1, 1, 1, 1]), ("SamplePerfectTile", [1, 18, 1]), @@ -490,145 +488,8 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, expected_decisions=[decision_0], ) - -def test_conv2d_more_intrin(): - # test adding inapplicable tensor intrinsics doesn't change the search space - # fmt: off - @T.prim_func - def conv2d_more_intrin_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, 3, 32, 32), "float16"], conv2d_nhwc: T.Buffer[(1, 16, 16, 32), "float32"]) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16") - conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], dtype="float32", scope="shared") - conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 32], dtype="float32", scope="wmma.accumulator") - PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", scope="shared") - weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", scope="shared") - PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], dtype="float16", scope="wmma.matrix_a") - weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], dtype="float16", scope="wmma.matrix_b") - for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): - with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) - T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16") - for ax0_0_ax1_0_0_ax2_0_0_fused in T.thread_binding(4, thread="blockIdx.y"): - for ax0_1_ax1_0_1_ax2_0_1_fused in T.thread_binding(4, thread="blockIdx.x"): - for ax0_2_ax1_0_2_ax2_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): - for ax3_0_0 in T.serial(3): - for ax0_ax1_fused in T.serial(1536): - with T.block("PadInput_reindex_shared"): - v0 = T.axis.spatial(256, ax0_0_ax1_0_0_ax2_0_0_fused * 64 + ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0_ax1_fused // 96) - v1 = T.axis.spatial(288, ax3_0_0 * 96 + ax0_ax1_fused % 96) - T.reads(PadInput[0, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]) - T.writes(PadInput_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) - PadInput_reindex_shared[v0, v1] = PadInput[0, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32] - for ax0_ax1_fused in T.serial(3072): - with T.block("weight_reindex_shared"): - v0 = T.axis.spatial(288, ax3_0_0 * 96 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(32, ax0_ax1_fused % 32) - T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]) - T.writes(weight_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) - weight_reindex_shared[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1] - for ax3_0_1 in T.serial(2): - for ax0_0, ax1_0 in T.grid(1, 3): - with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): - v0_o = T.axis.spatial(16, ax0_0_ax1_0_0_ax2_0_0_fused * 4 + ax0_1_ax1_0_1_ax2_0_1_fused) - v1_o = T.axis.spatial(18, ax3_0_0 * 6 + ax3_0_1 * 3 + ax1_0) - T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - 
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("PadInput_reindex_shared_wmma.matrix_a"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_0, ax1_0 in T.grid(3, 2): - with T.block("weight_reindex_shared_wmma.matrix_b_o"): - v0_o = T.axis.spatial(18, ax3_0_0 * 6 + ax3_0_1 * 3 + ax0_0) - v1_o = T.axis.spatial(2, ax1_0) - T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("weight_reindex_shared_wmma.matrix_b"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_3, ax1_0_3, ax2_0_3, ax3_0_2, ax0_4, ax1_0_4, ax2_0_4 in T.grid(1, 1, 2, 3, 1, 1, 1): - with T.block("conv2d_nhwc_o"): - v0 = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(16, ax1_0_4 + ax0_0_ax1_0_0_ax2_0_0_fused * 4 + ax0_1_ax1_0_1_ax2_0_1_fused + ax1_0_3) - v2_o = T.axis.spatial(2, ax2_0_4 + ax2_0_3) - v3_o = T.axis.reduce(18, ax3_0_0 * 6 + ax3_0_1 * 3 + ax3_0_2) - T.reads(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 : v1_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v3_o * 16 : v3_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) - with T.init(): - for ax1_1, ax2_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_init"): - v1_i_init, v2_i_init = T.axis.remap("SS", [ax1_1, ax2_1]) - T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init]) - conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init] = T.float32(0) - for ax1_1, ax2_1, ax3_1 in T.grid(16, 16, 16): - with T.block("conv2d_nhwc"): - v1_i, v2_i, v3_i = T.axis.remap("SSR", [ax1_1, ax2_1, ax3_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i], PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i], "float32") - for ax0_0, ax1_0 in T.grid(1, 2): - 
with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(16, ax0_0_ax1_0_0_ax2_0_0_fused * 4 + ax0_1_ax1_0_1_ax2_0_1_fused) - v1_o = T.axis.spatial(2, ax1_0) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0, ax1 in T.grid(16, 32): - with T.block("conv2d_nhwc_reindex_shared"): - v0 = T.axis.spatial(256, ax0_0_ax1_0_0_ax2_0_0_fused * 64 + ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0) - v1 = T.axis.spatial(32, ax1) - T.reads(conv2d_nhwc_reindex_shared[v0, v1]) - T.writes(conv2d_nhwc[0, v0 // 16, v0 % 16, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) - conv2d_nhwc[0, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] - # fmt: on - decision_0 = [ - ("SamplePerfectTile", [1, 1, 1, 1, 1]), - ("SamplePerfectTile", [4, 4, 1, 1, 1]), - ("SamplePerfectTile", [1, 1, 1, 2, 1]), - ("SamplePerfectTile", [3, 2, 3]), - ("SampleCategorical", 2), - ("SampleCategorical", 3), - ("SampleCategorical", 3), - ] - - mod = te.create_prim_func( - te_workload.conv2d_nhwc( - N=1, - H=16, - W=16, - CI=32, - CO=32, - kernel_size=3, - stride=1, - padding=1, - in_dtype="float16", - out_dtype="float32", - ) - ) + # Test adding inapplicable tensor intrinsics doesn't change the search space + # This test case uses the same workload, decision and the expected sketch as above actual = ms.TuneContext( mod=mod, target=tvm.target.Target("cuda"), @@ -643,7 +504,7 @@ def conv2d_more_intrin_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T check_sketches( mod, sketches=actual, - expected_mods=[conv2d_more_intrin_0], + expected_mods=[conv2d_0], expected_decisions=[decision_0], ) @@ -1088,5 +949,154 @@ def padded_matmul_relu_0(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 1 ) +def test_conv_1x1(): + # fmt: off + @T.prim_func + def conv2d_1x1_0(inputs: T.Buffer[(1, 16, 16, 64), "float16"], weight: T.Buffer[(1, 1, 64, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 16, 16, 64), "float32"]) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 64], dtype="float32", scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 64], dtype="float32", scope="wmma.accumulator") + PadInput_reindex_shared = T.alloc_buffer([256, 64], dtype="float16", scope="shared") + weight_reindex_shared = T.alloc_buffer([1, 1, 64, 64], dtype="float16", scope="shared") + PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 64], dtype="float16", scope="wmma.matrix_a") + weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 64, 64], dtype="float16", scope="wmma.matrix_b") + for ax2_0_0_ax3_0_0_fused in T.thread_binding(16, thread="blockIdx.y"): + for ax2_0_1_ax3_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): + for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax0_0, 
ax1_0, ax4_0_0 in T.grid(1, 1, 1): + for ax0_ax1_fused in T.serial(1024): + with T.block("PadInput_reindex_shared"): + v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64) + v1 = T.axis.spatial(64, ax0_ax1_fused % 64) + T.reads(inputs[v0 // 256, v0 // 16, v0 % 16, v1]) + T.writes(PadInput_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + PadInput_reindex_shared[v0, v1] = inputs[v0 // 256, v0 // 16, v0 % 16, v1] + for ax0_ax1_ax2_ax3_fused in T.serial(2048): + with T.block("weight_reindex_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused // 32) + v3 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) + T.reads(weight[v0, v1, v2, v3]) + T.writes(weight_reindex_shared[v0, v1, v2, v3]) + T.block_attr({"buffer_dim_align":[[0, 2, 32, 8]], "meta_schedule.cooperative_fetch":4}) + weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] + for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): + for ax0_0_1, ax1_0_1 in T.grid(1, 4): + with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) + v1_o = T.axis.spatial(4, ax1_0_1) + T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + for ax0_1_1, ax1_1_1 in T.grid(16, 16): + with T.block("PadInput_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) + T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 1): + with T.block("weight_reindex_shared_wmma.matrix_b_o"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2_o = T.axis.spatial(4, ax2_0) + v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused) + T.reads(weight_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + for ax2_1, ax3_1 in T.grid(16, 16): + with T.block("weight_reindex_shared_wmma.matrix_b"): + v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) + T.reads(weight_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] + for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 4, 1, 1): + with T.block("conv2d_nhwc_o"): + v0 = T.axis.reduce(1, 0) + v1 = T.axis.reduce(1, 0) + v2_o = T.axis.spatial(16, ax2_0_4 + ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3) + v3_o = T.axis.spatial(4, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3) + v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2) + 
T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 : v4_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + with T.init(): + for ax2_1, ax3_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_init"): + v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = T.float32(0) + for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): + with T.block("conv2d_nhwc"): + v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i], "float32") + for ax0_0, ax1_0 in T.grid(1, 1): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) + v1_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1 in T.grid(16, 32): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0) + v1 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax1) + T.reads(conv2d_nhwc_reindex_shared[v0, v1]) + T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1]) + T.block_attr({"meta_schedule.cooperative_fetch":2}) + conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] + # fmt: on + + decision_0 = [ + ("SamplePerfectTile", [1, 1, 1]), + ("SamplePerfectTile", [1, 1, 1]), + ("SamplePerfectTile", [8, 2, 1, 1, 1]), + ("SamplePerfectTile", [2, 1, 2, 1, 1]), + ("SamplePerfectTile", [1, 1, 4]), + ("SampleCategorical", 1), + ("SampleCategorical", 0), + 
("SampleCategorical", 2), + ] + + mod = te.create_prim_func( + te_workload.conv2d_nhwc( + 1, + 16, + 16, + 64, + 64, + 1, + 1, + 0, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = ms.TuneContext( + mod=mod, + target=tvm.target.Target("cuda"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[conv2d_1x1_0], + expected_decisions=[decision_0], + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 807420ece3ba..e0667da6fe92 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -360,8 +360,7 @@ def test_get_auto_tensorize_mapping_info_conv2d_unit_batch(): conv2d, "conv2d_nhwc", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, - # unit iter is not mapped - lambda n, h, w, c, rh, rw, rc: (n, h * 16 + w, c, rh * 192 + rw * 64 + rc), + lambda n, h, w, c, rh, rw, rc: (n * 256 + h * 16 + w, c, rh * 192 + rw * 64 + rc), ) @@ -388,7 +387,7 @@ def test_get_auto_tensorize_mapping_info_batch_matmul(b, m, n, k): k, ), ), - (1, 32, 32, None), + (1, 32, 32, lambda n, m, k: (n, m, k)), ], ) def test_get_auto_tensorize_mapping_info_matmul(n, m, k, expected): diff --git a/tests/python/unittest/test_tir_schedule_reindex.py b/tests/python/unittest/test_tir_schedule_reindex.py index 60dcefba631a..53bc726ceaf3 100644 --- a/tests/python/unittest/test_tir_schedule_reindex.py +++ b/tests/python/unittest/test_tir_schedule_reindex.py @@ -76,6 +76,37 @@ def conv2d_nhwc( ) +@T.prim_func +def conv2d_nhwc_reindex_data( + Input: T.Buffer[(1, 224, 224, 3), "float32"], + Weight: T.Buffer[(7, 7, 3, 64), "float32"], + Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"], +) -> None: + PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") + ReindexInput = T.alloc_buffer([1, 112, 112, 7, 7, 3], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 < 227)), + Input[i0_1, (i1_1 - 3), (i2_1 - 3), i3_1], + T.float32(0), + dtype="float32", + ) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 112, 112, 7, 7, 3): + with T.block("ReindexInput"): + n, h, w, rh, rw, rc = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) + ReindexInput[n, h, w, rh, rw, rc] = PadInput[n, ((h * 2) + rh), ((w * 2) + rw), rc] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3): + with T.block("conv2d_nhwc"): + n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + with T.init(): + Conv2d_nhwc[n, h, w, co] = T.float32(0) + Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + ( + ReindexInput[n, h, w, rh, rw, rc] * Weight[rh, rw, rc, co] + ) + + @T.prim_func def conv2d_nhwc_reindex_weight( var_inputs: T.handle, var_weight: T.handle, var_conv2d_nhwc: T.handle @@ -208,6 +239,45 @@ def mixed_dtype_reindex_write( T_matmul_NT[v0, v1] = T_matmul_NT_reindex[v0, v1] +@T.prim_func +def matmul_unit_dim( + A: T.Buffer[(1, 512), "float32"], + B: T.Buffer[(512, 1), "float32"], + C: T.Buffer[(1, 1), "float32"], +) -> None: + for i0, i1, i2 in T.grid(1, 1, 512): + with T.block("matmul"): + i, j, k = 
T.axis.remap("SSR", [i0, i1, i2]) + T.reads(C[i, j], A[i, k], B[k, j]) + T.writes(C[i, j]) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + +@T.prim_func +def matmul_unit_dim_reindex_write( + A: T.Buffer[(1, 512), "float32"], + B: T.Buffer[(512, 1), "float32"], + C: T.Buffer[(1, 1), "float32"], +) -> None: + C_reindex = T.alloc_buffer([1, 1], dtype="float32") + for i0, i1, i2 in T.grid(1, 1, 512): + with T.block("matmul"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(C_reindex[i, j], A[i, k], B[k, j]) + T.writes(C_reindex[i, j]) + with T.init(): + C_reindex[i, j] = T.float32(0) + C_reindex[i, j] = C_reindex[i, j] + A[i, k] * B[k, j] + for i0, i1 in T.grid(1, 1): + with T.block("C_reindex"): + v0, v1 = T.axis.remap("SS", [i0, i1]) + T.reads(C_reindex[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_reindex[v0, v1] + + use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) use_buffer_name = tvm.testing.parameter(by_dict={"buffer_index": False, "buffer_name": True}) @@ -221,7 +291,7 @@ def test_reindex_read_basic(use_block_name, use_buffer_name): verify_trace_roundtrip(sch=sch, mod=transpose_elementwise) -def test_conv2d_reindex_read(use_block_name, use_buffer_name): +def test_conv2d_reindex_weight(use_block_name, use_buffer_name): sch = tir.Schedule(conv2d_nhwc) block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") buf = "Weight" if use_buffer_name else ("read", 1) @@ -230,6 +300,15 @@ def test_conv2d_reindex_read(use_block_name, use_buffer_name): verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) +def test_conv2d_reindex_data(use_block_name, use_buffer_name): + sch = tir.Schedule(conv2d_nhwc) + block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") + buf = "PadInput" if use_buffer_name else ("read", 0) + sch.reindex(block, buf) + tvm.ir.assert_structural_equal(conv2d_nhwc_reindex_data, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) + + def test_matmul_reindex_write(use_block_name, use_buffer_name): sch = tir.Schedule(matmul) block = "matmul" if use_block_name else sch.get_block("matmul") @@ -256,5 +335,14 @@ def test_reindex_mixed_dtype(use_block_name, use_buffer_name): verify_trace_roundtrip(sch=sch, mod=mixed_dtype) +def test_matmul_unit_dim_reindex_write(use_block_name, use_buffer_name): + sch = tir.Schedule(matmul_unit_dim) + block = "matmul" if use_block_name else sch.get_block("matmul") + buf = "C" if use_buffer_name else ("write", 0) + sch.reindex(block, buf) + tvm.ir.assert_structural_equal(matmul_unit_dim_reindex_write, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=matmul_unit_dim) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 0b0146ee43fa..174e9eb25cc0 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -152,21 +152,27 @@ def conv2d_nhwc_transformed( T.float32(0), dtype="float32", ) - for ax0, ax_1, ax_2 in T.grid(12544, 64, 147): + for ax0, ax1, ax2 in T.grid(12544, 64, 147): with T.block("conv2d_nhwc"): - bv0, bv1, bv2 = T.axis.remap("SSR", [ax0, ax_1, ax_2]) - T.reads( - PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3], - Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1], - ) - T.writes(Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1]) + v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, 
ax2]) + T.reads(PadInput[v0 // 12544, v0 // 112 * 2 + v2 // 21, v0 % 112 * 2 + v2 % 21 // 3, v2 % 3], Weight[v2 // 21, v2 % 21 // 3, v2 % 3, v1]) + T.writes(Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1]) with T.init(): - Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = T.float32(0) - Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = ( - Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] - + PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3] - * Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1] - ) + Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1] = T.float32(0) + Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1] = Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1] + PadInput[v0 // 12544, v0 // 112 * 2 + v2 // 21, v0 % 112 * 2 + v2 % 21 // 3, v2 % 3] * Weight[v2 // 21, v2 % 21 // 3, v2 % 3, v1] + + +@T.prim_func +def two_elementwise_unit_dim(A: T.Buffer[(1, 128), "float32"], C: T.Buffer[(1, 128), "float32"]) -> None: + B = T.alloc_buffer((1, 128), "float32") + for i, j in T.grid(1, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(1, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on @@ -225,6 +231,24 @@ def test_two_elementwise_transform_output_buffer(use_block_name): verify_trace_roundtrip(sch=sch, mod=two_elementwise) +def test_two_elementwise_unit_dim(use_block_name): + sch = tir.Schedule(two_elementwise_unit_dim, debug_mask="all") + index_map = lambda i, j: (i, j) + + if use_block_name: + sch.transform_layout( + index_map=index_map, + block="B", + buffer="B", + ) + else: + block = sch.get_block("B") + sch.transform_layout(block, ("write", 0), index_map) + + tvm.ir.assert_structural_equal(two_elementwise_unit_dim, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=two_elementwise_unit_dim) + + def test_simplify(): sch = tir.Schedule(two_elementwise, debug_mask="all") @@ -312,6 +336,29 @@ def test_transform_block_layout_conv2d_nhwc(use_block_name): verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) +def test_transform_block_layout_unit_dim(use_block_name): + sch = tir.Schedule(two_elementwise_unit_dim, debug_mask="all") + block = "B" if use_block_name else sch.get_block("B") + sch.transform_block_layout(block, lambda i, j: (j, i)) + + @T.prim_func + def two_elementwise_unit_dim_transformed( + A: T.Buffer[(1, 128), "float32"], C: T.Buffer[(1, 128), "float32"] + ) -> None: + B = T.alloc_buffer((1, 128), "float32") + for j, i in T.grid(128, 1): + with T.block("B"): + vj, vi = T.axis.remap("SS", [j, i]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(1, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + tvm.ir.assert_structural_equal(two_elementwise_unit_dim_transformed, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=two_elementwise_unit_dim) + + def test_transform_block_layout_fail_non_affine(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") block = "B" if use_block_name else sch.get_block("B") From 3a8a2023d807373bc871643a2f9067ed31c2ea2c Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 4 Oct 2022 15:51:14 -0700 Subject: [PATCH 2/2] fix dtype --- src/tir/schedule/primitive/layout_transformation.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc 
b/src/tir/schedule/primitive/layout_transformation.cc index d14ea317f847..9d36a5f7e5c4 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1202,8 +1202,8 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // create block iter each expression in f(ax_0, ..., ax_n). Array new_block_iters; // new block iters Array new_block_vars; // iter_var->var of new block iters - for (size_t i = 0; i < index_map->final_indices.size(); ++i) { - Var new_block_var{"v" + std::to_string(i), DataType::Int(32)}; + for (size_t i = 0; i < transformed_block_iters.size(); ++i) { + Var new_block_var{"v" + std::to_string(i), transformed_block_iters[i]->dtype}; new_block_vars.push_back(new_block_var); IterVarType iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type); if (iter_type == kOpaque) { @@ -1245,7 +1245,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Make new loop vars Array new_loop_vars; for (int i = 0; i < static_cast(new_block_iters.size()); ++i) { - new_loop_vars.push_back(Var("ax" + std::to_string(i), DataType::Int(32))); + new_loop_vars.push_back(Var("ax" + std::to_string(i), new_block_iters[i]->var.dtype())); } // Make new block realize