From 46c8edb3122b8e966085c3074fa1a6637d992868 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 14 Jun 2022 16:53:53 -0700 Subject: [PATCH 1/6] [WIP] Tensorize Mapping proposer --- python/tvm/tir/schedule/analysis.py | 11 + src/tir/schedule/analysis.h | 43 +++ src/tir/schedule/analysis/analysis.cc | 275 ++++++++++++++++-- src/tir/schedule/ir_comparator.cc | 130 +++++++++ src/tir/schedule/ir_comparator.h | 42 +++ .../unittest/test_tir_schedule_analysis.py | 93 +++++- 6 files changed, 569 insertions(+), 25 deletions(-) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 71ff024217c7..c6bb12fea4b9 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -87,3 +87,14 @@ def get_tensorize_loop_mapping( TensorizeInfo structure if a valid mapping is found, None otherwise """ return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func) # type: ignore + + +@tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo") +class AutoTensorizeMappingInfo(Object): + """TODO""" + + +def get_tensorize_layout_info( + sch: Schedule, block: BlockRV, desc_func: PrimFunc +) -> Optional[AutoTensorizeMappingInfo]: + return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func) # type: ignore \ No newline at end of file diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 5adc4f8f1b30..8a7a9358d574 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -707,6 +707,49 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, const tir::PrimFunc& desc_func); +/*!\brief Necessary information used to perform transformations for tensorization */ +class AutoTensorizeMappingInfoNode : public Object { + public: + /*! \brief Possible mappings to apply to block iters */ + Array mapping; + /*! 
\brief Mapping from LHS buffer to RHS buffer */ + Map lhs_buffer_map; + + Map> rhs_indices_map; + Array lhs_iters, rhs_iters; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("mapping", &mapping); + v->Visit("rhs_indices_map", &rhs_indices_map); + v->Visit("lhs_iters", &lhs_iters); + v->Visit("rhs_iters", &rhs_iters); + } + + static constexpr const char* _type_key = "tir.schedule.AutoTensorizeMappingInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(AutoTensorizeMappingInfoNode, Object); +}; + +class AutoTensorizeMappingInfo : public ObjectRef { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AutoTensorizeMappingInfo, ObjectRef, + AutoTensorizeMappingInfoNode); +}; + +/*! + * \brief Get mapping info between a target block and an intrinsic description including layout + * transformations to apply. + * \param self The schedule state + * \param block_sref The compute block for auto tensorization + * \param desc_func The prim func describing the computation to be tensorized + * \return AutoTensorizeMappingInfo structure if a potential mapping is found, NullOpt otherwise. + * \note Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can be tensorized. + * We will need to apply the suggested layout transformations and then match against the tensor + * intrinsics. 
+ */ +Optional GetAutoTensorizeMappingInfo(const ScheduleState& self, + const StmtSRef& block_sref, + const PrimFunc& desc_func); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 7def8b8674e1..205feac57217 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -19,6 +19,7 @@ #include #include +#include "../ir_comparator.h" #include "../utils.h" namespace tvm { @@ -2085,39 +2086,60 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); -Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func) { - arith::Analyzer analyzer; - const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); - // Step 1. Analyze desc_func, extract its block, loops and loop vars - const tir::BlockRealizeNode* desc_block = nullptr; +/*! \brief Auxiliary data structure of information extracted from tensor intrin description */ +struct TensorIntrinDescInfo { + /*! \brief The block of the description function, which is the (unique) direct child of the root + * block. + */ + const BlockRealizeNode* desc_block = nullptr; + /*! \brief The loops of the description function, in the order from outer loops to inner ones. */ std::vector desc_loops; + /*! \brief The loop variables. */ std::unordered_set desc_loop_vars; - const auto* desc_scope_realize = desc_func->body.as(); +}; + +/*! + * \brief Extract auxilary information from the tensor intrin description. 
+ * \param analyze The arithmetic analyzer + * \param desc_func The description PrimFunc + * \return The auxilary information + */ +TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer, + const PrimFunc& desc_func) { + TensorIntrinDescInfo info; + const auto* desc_scope_realize = desc_func->body.as(); ICHECK(desc_scope_realize); { - auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars, - &analyzer](const ObjectRef& obj) -> bool { + auto f_visit = [&](const ObjectRef& obj) -> bool { // Extract the block - if (const auto* block = obj.as()) { - desc_block = block; + if (const auto* block = obj.as()) { + info.desc_block = block; return false; } - // Extract loops - if (const auto* loop = obj.as()) { - desc_loops.push_back(loop); - desc_loop_vars.insert(loop->loop_var.get()); - if (!analyzer.CanProve(loop->min == 0)) { + // Extract the loops + if (const auto* loop = obj.as()) { + info.desc_loops.push_back(loop); + info.desc_loop_vars.insert(loop->loop_var.get()); + if (!analyzer->CanProve(loop->min == 0)) { return false; } } return true; }; tir::PostOrderVisit(desc_scope_realize->block->body, f_visit); - std::reverse(desc_loops.begin(), desc_loops.end()); - ICHECK(desc_block); + std::reverse(info.desc_loops.begin(), info.desc_loops.end()); + ICHECK(info.desc_block); } + return info; +} + +Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func) { + arith::Analyzer analyzer; + const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); + // Step 1. Analyze desc_func, extract its block, loops and loop vars + TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); // Step 2. 
Collect loops from block_sref const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); @@ -2138,6 +2160,9 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, std::reverse(block_loops.begin(), block_loops.end()); } // Step 3. Map from block loops to desc block loops + const std::vector& desc_loops = desc_info.desc_loops; + const std::unordered_set& desc_loop_vars = desc_info.desc_loop_vars; + const BlockRealizeNode* desc_block = desc_info.desc_block; ObjectPtr ret = make_object(); const int n_block_vars = block->iter_values.size(); const int n_desc_vars = desc_block->iter_values.size(); @@ -2240,5 +2265,217 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping") return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func); }); +<<<<<<< HEAD +======= +/******** Auto Tensorization ********/ + +/*! \brief IndexMap proposer for layout transformation in auto tensorization. */ +class MappingProposer { + public: + static Array ProposeMappings(const AutoTensorizeExtractor* extractor) { + MappingProposer proposer(extractor); + proposer.CollectFeasibleSet(); + proposer.ProposeAllFuseMapping(); + return proposer.mappings_; + } + + private: + explicit MappingProposer(const AutoTensorizeExtractor* extractor) : extractor_(extractor) {} + + using VarSet = std::unordered_set; + + std::string to_string(const VarSet& vs) { + std::ostringstream os; + for (const auto& v : vs) { + os << v << ", "; + } + return os.str(); + }; + + void CollectFeasibleSet() { + // Collect the set of potential iter var mapping between the workload and the tensor intrin. + // We analyze the appearance of each variable in the buffer indices of each buffer on LHS and + // RHS. The appearance of a variable in the buffer indices is encoded as bit-masks (BufferMask). + // Variables on the LHS and the RHS with the same bit-mask are potential mappings. 
+ // + // For example, consider the conv2d case. We will try to match the workload + // conv2d[n, h, w, c] = sum_{rh, rw, rc} X[n, h + rh, w + rw, c + rc] * W[rh, rw, rc, c] + // against a matmul tensor intrin + // C[m, n] = sum_{k} A[m, k] * B[k, n] + // First we extract the correspondence of the buffers: conv2d <=> C, A <=> X, B <=> W. + // Then for each variable, we extract the buffers where it is used for indexing. + // Take the variable m on the RHS as an example. m is used to index buffer A and C. On the LHS, + // we will find the variables used to index only the exact corresponding buffers conv2d and X + // (the variable is not allowed to index other buffers). In this case, n, h, w is used to index + // both buffer conv2d and W, and not in other buffers. Therefore, {n, h, w} <=> m is a potential + // mapping. + + // Note: the mapping is not unique when multiple variables in RHS has the same bit-mask. + // This is currently not supported. + + using BufferMask = std::vector; + + // Step 1: Assign an index to each buffer in LHS and RHS + std::unordered_map rhs_buffer_index; + std::unordered_map lhs_buffer_index; + { + int i = 0; + for (const auto& kv : extractor_->rhs_buffer_map_) { + const Buffer& rhs_buffer = kv.first; + const Buffer& lhs_buffer = kv.second; + rhs_buffer_index[rhs_buffer] = i; + lhs_buffer_index[lhs_buffer] = i; + ++i; + } + } + + // Step 2: Compute the buffer mask + ICHECK_EQ(rhs_buffer_index.size(), lhs_buffer_index.size()); + int num_buffers = rhs_buffer_index.size(); + std::unordered_map> rhs_buffer_masks, lhs_buffer_masks; + // helper function to initialize or update the buffer mask + auto update_mask = [&](const VarNode* var, + std::unordered_map>* masks, int i) { + if (!masks->count(var)) { + (*masks)[var].resize(num_buffers); + } + (*masks)[var][i] = true; + }; + + for (const auto& it : extractor_->rhs_buffer_indices_map_) { + const Buffer& rhs_buffer = it.first; + for (const PrimExpr& rhs_index : it.second) { + if (const VarNode* 
var_node = rhs_index.as()) { + update_mask(var_node, &rhs_buffer_masks, rhs_buffer_index.at(rhs_buffer)); + } else { + LOG(FATAL) << "ValueError: Buffer index " << rhs_index + << " other that variables in tensor intrinsics is not supported."; + } + } + + auto lhs_buffer_it = extractor_->rhs_buffer_map_.find(rhs_buffer); + ICHECK(lhs_buffer_it != extractor_->rhs_buffer_map_.end()); + const Buffer& lhs_buffer = lhs_buffer_it->second; + for (const PrimExpr& index : extractor_->lhs_buffer_indices_map_.at(lhs_buffer)) { + PreOrderVisit(index, [&](const ObjectRef& obj) -> bool { + if (const VarNode* var = obj.as()) { + update_mask(var, &lhs_buffer_masks, lhs_buffer_index.at(lhs_buffer)); + } + return true; + }); + } + } + + // Step 3: Find variables on LHS and RHS with the same buffer mask + std::unordered_map mask_to_rhs_vars; + for (const auto& kv : rhs_buffer_masks) { + const VarNode* rhs_var = kv.first; + const BufferMask& mask = kv.second; + mask_to_rhs_vars[mask].insert(GetRef(rhs_var)); + } + + for (const auto& iter : extractor_->lhs_iters_) { + // lhs_representers.push_back(iter->var.copy_with_suffix("_l")); + lhs_feasible_vars_[iter->var] = mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]]; + } + } + + void ProposeAllFuseMapping() { + // Now we have calcuated potential mapping for each iter var on LHS. For iters on LHS mapped to + // the same iter on RHS, they will be fused in the original order in LHS block iters. We will + // generate IndexMap to represent such fusion on LHS. For example, if n, h, w on LHS are mapped + // to the same iter var on RHS, we will produce index map `lambda n, h, w: fuse(n, h, w)`, where + // fuse(v0, .., vn) = ((v0 * v1_extent + v1) + ... 
) * vn_extent + vn + + // the parameters of the result index map, each parameter corresponds to a LHS iter + Array index_map_src; + // the outputs of the result index map + Array index_map_tgt; + + // Step 1: Collect extents of LHS iters and prepare the initial indices of the IndexMap + Map lhs_iter_extents; + for (const auto& iter : extractor_->lhs_iters_) { + lhs_iter_extents.Set(iter->var, iter->dom->extent); + index_map_src.push_back(iter->var.copy_with_suffix("")); + } + + // Step 2: Each iter on RHS has a group of corresponding iters on LHS. Initialize the fusion + // result for each group of iters on LHS. + Map fused_lhs_iters; + for (const auto& iter : extractor_->rhs_iters_) { + fused_lhs_iters.Set(iter->var, 0); + } + + // Step 3: Fuse LHS iters mapped to the same RHS iter + for (size_t i = 0; i < extractor_->lhs_iters_.size(); ++i) { + const Var& lhs_iter_var = extractor_->lhs_iters_[i]->var; + const VarSet& rhs_candidates = lhs_feasible_vars_[lhs_iter_var]; + if (rhs_candidates.empty()) { + // put unmapped iters at the beginning + index_map_tgt.push_back(index_map_src[i]); + } else if (rhs_candidates.size() == 1) { + Var rhs_var = *rhs_candidates.begin(); + PrimExpr fused_lhs = fused_lhs_iters.at(rhs_var); + PrimExpr updated_fused_lhs = fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i]; + fused_lhs_iters.Set(rhs_var, updated_fused_lhs); + } else { + // non-unique mapping is not supported + return; + } + } + arith::Analyzer analyzer; + for (const auto& iter : extractor_->rhs_iters_) { + index_map_tgt.push_back(analyzer.Simplify(fused_lhs_iters[iter->var])); + } + mappings_.push_back(IndexMap(index_map_src, index_map_tgt)); + LOG(INFO) << mappings_[0]; + } + + public: + Array lhs_representers; + std::unordered_map lhs_buffer_map_; + // std::unordered_map rhs_feasible_vars_; + std::unordered_map lhs_feasible_vars_; + Array mappings_; + const AutoTensorizeExtractor* extractor_; +}; + +Optional GetAutoTensorizeMappingInfo(const 
tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func) { + arith::Analyzer analyzer; + const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); + // Step 1. Analyze desc_func, extract its block, loops and loop vars + TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); + // Step 2. Check if `desc_block` matches `block` + // Ignore the scope of buffers when comparing, since we can do cache_read/write + const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); + const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + AutoTensorizeExtractor extractor(self->mod); + if (!extractor.VisitStmt(block->block, desc_info.desc_block->block)) { + return NullOpt; + } + Array mappings = MappingProposer::ProposeMappings(&extractor); + if (mappings.empty()) { + return NullOpt; + } + ObjectPtr ret = make_object(); + // Only using 1 layout now + ret->mapping = std::move(mappings); + ret->lhs_buffer_map = std::move(extractor.lhs_buffer_map_); + ret->rhs_indices_map = std::move(extractor.rhs_buffer_indices_map_); + ret->lhs_iters = std::move(extractor.lhs_iters_); + ret->rhs_iters = std::move(extractor.rhs_iters_); + return AutoTensorizeMappingInfo(ret); +} + +TVM_REGISTER_NODE_TYPE(AutoTensorizeMappingInfoNode); + +TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo") + .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) { + return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func); + }); + +>>>>>>> 19a13545e ([WIP] Tensorize Mapping proposer) } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 58c502379a7a..a1f1aedf85bf 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -355,5 +355,135 @@ void TensorizeComparator::EmitError(const std::string& error_message) { error_messages_.push_back(error_message); } 
+/******** AutoTensorize Extractor ********/ + +bool AutoTensorizeExtractor::VisitExprDefault_(const Object* op, const PrimExpr& other) { + return false; +} + +bool AutoTensorizeExtractor::VisitStmtDefault_(const Object* op, const Stmt& other) { + return false; +} + +template +bool AutoTensorizeExtractor::CompareArray(const Array& lhs, const Array& rhs, F cmp) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!(this->*cmp)(lhs[i], rhs[i])) return false; + } + return true; +} + +bool AutoTensorizeExtractor::VisitStmt_(const BlockNode* op, const Stmt& other) { + const auto* rhs = other.as(); + // Check block equality. + // All iter vars and buffer regions including the order should match. + // When checking iter vars, DefEqual is used to remap variables. + if (!is_scope_block) { + if (!CompareArray(op->iter_vars, rhs->iter_vars, &AutoTensorizeExtractor::CompareIterVar)) { + return false; + } + if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { + return false; + } + if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, + &AutoTensorizeExtractor::CompareBuffer)) { + return false; + } + for (const IterVar& block_iter : op->iter_vars) { + inner_iter_dom_map_.Set(block_iter->var, arith::IntSet::FromRange(block_iter->dom)); + } + } else { + auto collect_iter = [&](const BlockNode* op, std::vector& iters) -> bool { + for (const auto& iter : op->iter_vars) { + analyzer_.Bind(iter->var, iter->dom); + if (iter->iter_type == IterVarType::kDataPar || + iter->iter_type == IterVarType::kCommReduce) { + iters.push_back(iter); + } else { + return false; + } + } + return true; + }; + if (!collect_iter(op, lhs_iters_)) { + return false; + } + if (!collect_iter(rhs, rhs_iters_)) { + return false; + } + } + is_scope_block = false; + return VisitStmt(op->body, rhs->body); +} + +bool AutoTensorizeExtractor::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { + if (lhs.same_as(rhs)) 
return true; + auto it = rhs_buffer_map_.find(rhs); + bool equal; + if (it != rhs_buffer_map_.end()) { + equal = (*it).second.same_as(lhs); + } else { + // Remap both buffer itself and buffer data, skip buffer shape and scope + equal = DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype; + if (equal) { + rhs_buffer_map_[rhs] = lhs; + lhs_buffer_map_[lhs] = rhs; + } + } + return equal; +} + +bool AutoTensorizeExtractor::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); +} + +bool AutoTensorizeExtractor::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs); +} + +template +bool AutoTensorizeExtractor::CompareBufferAccess(const T* lhs, const T* rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + auto it_lhs = lhs_buffer_indices_map_.find(lhs->buffer); + if (it_lhs == lhs_buffer_indices_map_.end()) { + if (rhs_buffer_indices_map_.find(rhs->buffer) != rhs_buffer_indices_map_.end()) { + return false; + } + std::vector lhs_indices; + for (const auto& index : lhs->indices) { + lhs_indices.push_back(analyzer_.Simplify(index)); + } + for (const auto& index : rhs->indices) { + if (!index.template as()) return false; + } + lhs_buffer_indices_map_[lhs->buffer] = lhs_indices; + rhs_buffer_indices_map_[rhs->buffer] = rhs->indices; + } else { + auto it_rhs = rhs_buffer_indices_map_.find(rhs->buffer); + if (it_rhs == rhs_buffer_indices_map_.end()) { + return false; + } + auto indices_check = [&](const Array& indices, + const Array& old_indices) -> bool { + if (indices.size() != old_indices.size()) { + return false; + } + for (size_t i = 0; i < indices.size(); ++i) { + if (!analyzer_.CanProveEqual(indices[i], old_indices[i])) { + return false; + } + } + return true; + }; + if (!indices_check(lhs->indices, it_lhs->second)) return false; + if 
(!indices_check(rhs->indices, it_rhs->second)) return false; + } + return true; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index 359677d8852f..df0536a13a15 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -110,6 +110,48 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { std::unordered_map equal_map_; }; +/*! \brief IR comparator for auto tensorization. Extract correspondence between the IR of the + * workload and the tensor intrin. + */ +class AutoTensorizeExtractor : public TensorizeComparator { + public: + explicit AutoTensorizeExtractor(const IRModule& lhs_mod) + : TensorizeComparator(lhs_mod, /* assert_mode=*/false) {} + + private: + bool VisitExprDefault_(const Object* op, const PrimExpr& other) override; + bool VisitStmtDefault_(const Object* op, const Stmt& other) override; + + bool VisitStmt_(const BlockNode* op, const Stmt& other) override; + bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; + + bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; + + template + bool CompareArray(const Array& lhs, const Array& rhs, F cmp); + bool CompareBuffer(const Buffer& lhs, const Buffer& rhs) override; + template + bool CompareBufferAccess(const T* lhs, const T* rhs); + + public: + /*! \brief Block iters in the LHS stmt. */ + std::vector lhs_iters_; + /*! \brief Block iters in the RHS stmt. */ + std::vector rhs_iters_; + /*! \brief The buffer and its access indices in the LHS stmt. */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + lhs_buffer_indices_map_; + /*! \brief The buffer and its access indices in the RHS stmt. */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + rhs_buffer_indices_map_; + /*! \brief Map from LHS buffer to RHS buffer */ + std::unordered_map lhs_buffer_map_; + + private: + /*! \brief The domain of the inner block iters. 
*/ + Map inner_iter_dom_map_; +}; + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 19be0b8699ac..b06cd6f42ec9 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -18,12 +18,15 @@ from typing import List import tvm +import tvm.testing +from tvm.tir.function import TensorIntrin from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc +from tvm.tir.tensor_intrin import cuda as cuda_intrin from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule from tvm.tir.analysis import expr_deep_equal -from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping, TensorizeInfo +from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping, get_tensorize_layout_info, TensorizeInfo from tvm.script import tir as T from tvm.tir.stmt_functor import pre_order_visit from tvm.meta_schedule.testing import te_workload @@ -156,6 +159,72 @@ def main( "int32", ) +@T.prim_func +def conv2d_nhwc_hwio( + Input: T.Buffer[(4, 16, 16, 64), "float16"], + Weight: T.Buffer[(3, 3, 64, 64), "float16"], + Conv2d_nhwc: T.Buffer[(4, 16, 16, 64), "float32"], +) -> None: + PadInput = T.alloc_buffer([4, 18, 18, 64], dtype="float16") + for i0, i1, i2, i3 in T.grid(4, 18, 18, 64): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + ((((i1_1 >= 1) and (i1_1 < 17)) and (i2_1 >= 1)) and (i2_1 < 17)), + Input[i0_1, (i1_1 - 1), (i2_1 - 1), i3_1], + T.float32(0), + dtype="float32", + ) + for i0, i1, i2, i3, i4, i5, i6 in T.grid(4, 16, 16, 64, 3, 3, 64): + with T.block("conv2d_nhwc"): + n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + with T.init(): + Conv2d_nhwc[n, h, w, co] = T.float32(0) + Conv2d_nhwc[n, 
h, w, co] = Conv2d_nhwc[n, h, w, co] + ( + T.cast(PadInput[n, h + rh, w + rw, rc], 'float32') + * T.cast(Weight[rh, rw, rc, co], 'float32') + ) + +@T.prim_func +def conv2d_nhwc_ohwi( + Input: T.Buffer[(4, 16, 16, 64), "float16"], + Weight: T.Buffer[(64, 3, 3, 64), "float16"], + Conv2d_nhwc: T.Buffer[(4, 16, 16, 64), "float32"], +) -> None: + PadInput = T.alloc_buffer([4, 18, 18, 64], dtype="float16") + for i0, i1, i2, i3 in T.grid(4, 18, 18, 64): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + ((((i1_1 >= 1) and (i1_1 < 17)) and (i2_1 >= 1)) and (i2_1 < 17)), + Input[i0_1, (i1_1 - 1), (i2_1 - 1), i3_1], + T.float32(0), + dtype="float32", + ) + for i0, i1, i2, i3, i4, i5, i6 in T.grid(4, 16, 16, 64, 3, 3, 64): + with T.block("conv2d_nhwc"): + n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + with T.init(): + Conv2d_nhwc[n, h, w, co] = T.float32(0) + Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + ( + T.cast(PadInput[n, h + rh, w + rw, rc], 'float32') + * T.cast(Weight[co, rh, rw, rc], 'float32') + ) + +@T.prim_func +def batch_matmul( + X: T.Buffer[(16, 32, 128), "float16"], + W: T.Buffer[(16, 128, 64), "float16"], + Y: T.Buffer[(16, 32, 64), "float32"] +) -> None: + for i0, i1, i2, i3 in T.grid(16, 32, 64, 128): + with T.block("batch_matmul"): + b, m, n, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + with T.init(): + Y[b, m, n] = T.float32(0) + Y[b, m, n] += T.cast(X[b, m, k], "float32") * T.cast(W[b, k, n], "float32") + + def collect_loops(prim_func): loops = [] @@ -252,9 +321,21 @@ def matmul_16x16x16xf16f16f16_desc( assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2) +def get_intrin_desc(intrin_name): + return TensorIntrin.get(intrin_name).desc +def test_get_tensorize_layout_info(): + s = Schedule(conv2d_nhwc_hwio) + block = s.get_block('conv2d_nhwc') + print('get mapping') + info = get_tensorize_layout_info(s, block, 
get_intrin_desc(cuda_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN)) + print(info.mapping) + +def test_get_tensorize_layout_info_gmm(): + s = Schedule(batch_matmul) + block = s.get_block('batch_matmul') + print(batch_matmul.script()) + info = get_tensorize_layout_info(s, block, get_intrin_desc(cuda_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN)) + print(info.mapping) + if __name__ == "__main__": - test_suggest_index_map_simple() - test_suggest_index_map_bijective() - test_get_tensorize_loop_mapping_dense_vnni() - test_get_tensorize_loop_mapping_conv2d_nchwc_vnni() - test_get_tensorize_loop_mapping_matmul_mma() + test_get_tensorize_layout_info() \ No newline at end of file From 59015c1e1014ce5369d5951c8bf1b17e4739b9b9 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 15 Jun 2022 14:55:44 -0700 Subject: [PATCH 2/6] Tensorize mapping proposer --- .../tvm/meta_schedule/testing/te_workload.py | 68 +++++++++ python/tvm/tir/schedule/analysis.py | 29 +++- src/tir/schedule/analysis.h | 18 ++- src/tir/schedule/analysis/analysis.cc | 64 ++++----- .../unittest/test_tir_schedule_analysis.py | 129 ++++++------------ 5 files changed, 178 insertions(+), 130 deletions(-) diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py index 52f5f49b0a12..28a2df628c53 100644 --- a/python/tvm/meta_schedule/testing/te_workload.py +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -701,6 +701,74 @@ def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid- return (a, b) +def conv2d_nhwc_f16( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +): + inputs = te.placeholder((N, H, W, CI), name="inputs", dtype="float16") + weight = te.placeholder( + (kernel_size, kernel_size, CI // groups, CO), name="weight", dtype="float16" + ) + batch_size, in_h, in_w, _ = inputs.shape + 
k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, co: te.sum( + ( + tir.Cast( + value=padded[ + n, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ], + dtype="float32", + ) + * tir.Cast(value=weight[rh, rw, rc, co], dtype="float32") + ), + axis=[rh, rw, rc], + ), + name="conv2d_nhwc", + ) + return (inputs, weight, output) + + +def batch_matmul_nkkm_f16( # pylint: disable=invalid-name,missing-docstring + B: int, + N: int, + M: int, + K: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((B, N, K), name="X", dtype="float16") + y = te.placeholder((B, K, M), name="Y", dtype="float16") + k = te.reduce_axis((0, K), name="k") + z = te.compute( # pylint: disable=invalid-name + (B, N, M), + lambda b, i, j: te.sum( + tir.Cast("float32", x[b][i][k]) * tir.Cast("float32", y[b][k][j]), axis=[k] + ), + name="Z", + ) + return (x, y, z) + + def create_te_workload(name: str, idx: int) -> tir.PrimFunc: workload_func, params = CONFIGS[name] return te.create_prim_func(workload_func(*params[idx])) # type: ignore diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index c6bb12fea4b9..cdb4aa9cfa20 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -91,10 +91,33 @@ def get_tensorize_loop_mapping( @tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo") class AutoTensorizeMappingInfo(Object): - """TODO""" + """Necessary information used 
to perform transformations for tensorization.""" -def get_tensorize_layout_info( +def get_auto_tensorize_mapping_info( sch: Schedule, block: BlockRV, desc_func: PrimFunc ) -> Optional[AutoTensorizeMappingInfo]: - return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func) # type: ignore \ No newline at end of file + """Get mapping info between a target block and an intrinsic description including layout + transformations to apply. + + Parameters + ---------- + sch : Schedule + The schedule to be tensorized + block : BlockRV + The compute block for auto tensorization + desc_func : PrimFunc + The prim func describing the computation to be tensorized + + Returns + ------- + auto_tensorize_mapping_info : Optional[AutoTensorizeMappingInfo] + AutoTensorizeMappingInfo structure if potential mappings found, None otherwise. + + Note + ---- + Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can be tensorized. + We will need to apply the suggested layout transformations and then match against the tensor + intrinsics. + """ + return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func) # type: ignore diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 8a7a9358d574..329a9a57971b 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -711,16 +711,22 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, class AutoTensorizeMappingInfoNode : public Object { public: /*! \brief Possible mappings to apply to block iters */ - Array mapping; + Array mappings; + + /* Additional information from AutoTensorizeExtractor */ + /*! \brief Mapping from LHS buffer to RHS buffer */ Map lhs_buffer_map; - - Map> rhs_indices_map; - Array lhs_iters, rhs_iters; + /*! \brief Buffer indices on RHS */ + Map> rhs_buffer_indices; + /*! \brief Block iters on LHS */ + Array lhs_iters; + /*! 
\brief Block iters on RHS */ + Array rhs_iters; void VisitAttrs(AttrVisitor* v) { - v->Visit("mapping", &mapping); - v->Visit("rhs_indices_map", &rhs_indices_map); + v->Visit("mappings", &mappings); + v->Visit("rhs_buffer_indices", &rhs_buffer_indices); v->Visit("lhs_iters", &lhs_iters); v->Visit("rhs_iters", &rhs_iters); } diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 205feac57217..5465eca41ece 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2265,33 +2265,25 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping") return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func); }); -<<<<<<< HEAD -======= /******** Auto Tensorization ********/ /*! \brief IndexMap proposer for layout transformation in auto tensorization. */ -class MappingProposer { +class AutoTensorizeMappingProposer { public: - static Array ProposeMappings(const AutoTensorizeExtractor* extractor) { - MappingProposer proposer(extractor); + static Array ProposeMappings(const AutoTensorizeExtractor* extractor, + arith::Analyzer* analyzer) { + AutoTensorizeMappingProposer proposer(extractor, analyzer); proposer.CollectFeasibleSet(); - proposer.ProposeAllFuseMapping(); - return proposer.mappings_; + return proposer.ProposeAllFuseMapping(); } private: - explicit MappingProposer(const AutoTensorizeExtractor* extractor) : extractor_(extractor) {} + explicit AutoTensorizeMappingProposer(const AutoTensorizeExtractor* extractor, + arith::Analyzer* analyzer) + : extractor_(extractor), analyzer_(analyzer) {} using VarSet = std::unordered_set; - std::string to_string(const VarSet& vs) { - std::ostringstream os; - for (const auto& v : vs) { - os << v << ", "; - } - return os.str(); - }; - void CollectFeasibleSet() { // Collect the set of potential iter var mapping between the workload and the tensor intrin. 
// We analyze the appearance of each variable in the buffer indices of each buffer on LHS and @@ -2375,12 +2367,11 @@ class MappingProposer { } for (const auto& iter : extractor_->lhs_iters_) { - // lhs_representers.push_back(iter->var.copy_with_suffix("_l")); lhs_feasible_vars_[iter->var] = mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]]; } } - void ProposeAllFuseMapping() { + Array ProposeAllFuseMapping() { // Now we have calcuated potential mapping for each iter var on LHS. For iters on LHS mapped to // the same iter on RHS, they will be fused in the original order in LHS block iters. We will // generate IndexMap to represent such fusion on LHS. For example, if n, h, w on LHS are mapped @@ -2416,33 +2407,34 @@ class MappingProposer { } else if (rhs_candidates.size() == 1) { Var rhs_var = *rhs_candidates.begin(); PrimExpr fused_lhs = fused_lhs_iters.at(rhs_var); - PrimExpr updated_fused_lhs = fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i]; + PrimExpr updated_fused_lhs = + fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i]; fused_lhs_iters.Set(rhs_var, updated_fused_lhs); } else { // non-unique mapping is not supported - return; + return {}; } } - arith::Analyzer analyzer; for (const auto& iter : extractor_->rhs_iters_) { - index_map_tgt.push_back(analyzer.Simplify(fused_lhs_iters[iter->var])); + index_map_tgt.push_back(analyzer_->Simplify(fused_lhs_iters[iter->var])); } - mappings_.push_back(IndexMap(index_map_src, index_map_tgt)); - LOG(INFO) << mappings_[0]; + // At most one mapping is supported. + return {IndexMap(index_map_src, index_map_tgt)}; } - public: - Array lhs_representers; - std::unordered_map lhs_buffer_map_; - // std::unordered_map rhs_feasible_vars_; - std::unordered_map lhs_feasible_vars_; - Array mappings_; + private: + // The extractor that has extracted information for auto tensorization from the workload and the + // tensor intrin. const AutoTensorizeExtractor* extractor_; + // The arithmetic analyzer. 
+ arith::Analyzer* analyzer_; + /*! \brief Potential mappings on RHS for each variable on LHS */ + std::unordered_map lhs_feasible_vars_; }; Optional GetAutoTensorizeMappingInfo(const tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func) { + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func) { arith::Analyzer analyzer; const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); // Step 1. Analyze desc_func, extract its block, loops and loop vars @@ -2455,15 +2447,14 @@ Optional GetAutoTensorizeMappingInfo(const tir::Schedu if (!extractor.VisitStmt(block->block, desc_info.desc_block->block)) { return NullOpt; } - Array mappings = MappingProposer::ProposeMappings(&extractor); + Array mappings = AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer); if (mappings.empty()) { return NullOpt; } ObjectPtr ret = make_object(); - // Only using 1 layout now - ret->mapping = std::move(mappings); + ret->mappings = std::move(mappings); ret->lhs_buffer_map = std::move(extractor.lhs_buffer_map_); - ret->rhs_indices_map = std::move(extractor.rhs_buffer_indices_map_); + ret->rhs_buffer_indices = std::move(extractor.rhs_buffer_indices_map_); ret->lhs_iters = std::move(extractor.lhs_iters_); ret->rhs_iters = std::move(extractor.rhs_iters_); return AutoTensorizeMappingInfo(ret); @@ -2476,6 +2467,5 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo") return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func); }); ->>>>>>> 19a13545e ([WIP] Tensorize Mapping proposer) } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index b06cd6f42ec9..6761203a5a4d 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -16,17 +16,22 @@ # under the License. 
# pylint: disable=missing-docstring from typing import List - +import pytest import tvm import tvm.testing from tvm.tir.function import TensorIntrin from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc -from tvm.tir.tensor_intrin import cuda as cuda_intrin +from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f32_INTRIN from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule from tvm.tir.analysis import expr_deep_equal -from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping, get_tensorize_layout_info, TensorizeInfo +from tvm.tir.schedule.analysis import ( + get_auto_tensorize_mapping_info, + suggest_index_map, + get_tensorize_loop_mapping, + TensorizeInfo, +) from tvm.script import tir as T from tvm.tir.stmt_functor import pre_order_visit from tvm.meta_schedule.testing import te_workload @@ -159,72 +164,6 @@ def main( "int32", ) -@T.prim_func -def conv2d_nhwc_hwio( - Input: T.Buffer[(4, 16, 16, 64), "float16"], - Weight: T.Buffer[(3, 3, 64, 64), "float16"], - Conv2d_nhwc: T.Buffer[(4, 16, 16, 64), "float32"], -) -> None: - PadInput = T.alloc_buffer([4, 18, 18, 64], dtype="float16") - for i0, i1, i2, i3 in T.grid(4, 18, 18, 64): - with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( - ((((i1_1 >= 1) and (i1_1 < 17)) and (i2_1 >= 1)) and (i2_1 < 17)), - Input[i0_1, (i1_1 - 1), (i2_1 - 1), i3_1], - T.float32(0), - dtype="float32", - ) - for i0, i1, i2, i3, i4, i5, i6 in T.grid(4, 16, 16, 64, 3, 3, 64): - with T.block("conv2d_nhwc"): - n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) - with T.init(): - Conv2d_nhwc[n, h, w, co] = T.float32(0) - Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + ( - T.cast(PadInput[n, h + rh, w + rw, rc], 'float32') - * T.cast(Weight[rh, rw, rc, co], 'float32') - ) - -@T.prim_func -def conv2d_nhwc_ohwi( - Input: 
T.Buffer[(4, 16, 16, 64), "float16"], - Weight: T.Buffer[(64, 3, 3, 64), "float16"], - Conv2d_nhwc: T.Buffer[(4, 16, 16, 64), "float32"], -) -> None: - PadInput = T.alloc_buffer([4, 18, 18, 64], dtype="float16") - for i0, i1, i2, i3 in T.grid(4, 18, 18, 64): - with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( - ((((i1_1 >= 1) and (i1_1 < 17)) and (i2_1 >= 1)) and (i2_1 < 17)), - Input[i0_1, (i1_1 - 1), (i2_1 - 1), i3_1], - T.float32(0), - dtype="float32", - ) - for i0, i1, i2, i3, i4, i5, i6 in T.grid(4, 16, 16, 64, 3, 3, 64): - with T.block("conv2d_nhwc"): - n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) - with T.init(): - Conv2d_nhwc[n, h, w, co] = T.float32(0) - Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + ( - T.cast(PadInput[n, h + rh, w + rw, rc], 'float32') - * T.cast(Weight[co, rh, rw, rc], 'float32') - ) - -@T.prim_func -def batch_matmul( - X: T.Buffer[(16, 32, 128), "float16"], - W: T.Buffer[(16, 128, 64), "float16"], - Y: T.Buffer[(16, 32, 64), "float32"] -) -> None: - for i0, i1, i2, i3 in T.grid(16, 32, 64, 128): - with T.block("batch_matmul"): - b, m, n, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) - with T.init(): - Y[b, m, n] = T.float32(0) - Y[b, m, n] += T.cast(X[b, m, k], "float32") * T.cast(W[b, k, n], "float32") - - def collect_loops(prim_func): loops = [] @@ -321,21 +260,43 @@ def matmul_16x16x16xf16f16f16_desc( assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2) -def get_intrin_desc(intrin_name): - return TensorIntrin.get(intrin_name).desc -def test_get_tensorize_layout_info(): - s = Schedule(conv2d_nhwc_hwio) - block = s.get_block('conv2d_nhwc') - print('get mapping') - info = get_tensorize_layout_info(s, block, get_intrin_desc(cuda_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN)) - print(info.mapping) +def check_index_map(workload, block_name, intrin_name, expected_index_map): + s = Schedule(workload) + 
block = s.get_block(block_name) + desc_func = TensorIntrin.get(intrin_name).desc + info = get_auto_tensorize_mapping_info(s, block, desc_func) + assert len(info.mappings) == 1 + assert IndexMap.from_func(expected_index_map).is_equivalent_to(info.mappings[0]) + + +def test_get_auto_tensorize_mapping_info_conv2d(): + conv2d = create_prim_func(te_workload.conv2d_nhwc_f16(4, 16, 16, 64, 64, 3, 1, 1)) + check_index_map( + conv2d, + "conv2d_nhwc", + WMMA_SYNC_16x16x16_f16f16f32_INTRIN, + lambda n, h, w, c, rh, rw, rc: (n * 256 + h * 16 + w, c, rh * 192 + rw * 64 + rc), + ) + + +def test_get_auto_tensorize_mapping_info_conv2d_unit_batch(): + conv2d = create_prim_func(te_workload.conv2d_nhwc_f16(1, 16, 16, 64, 64, 3, 1, 1)) + check_index_map( + conv2d, + "conv2d_nhwc", + WMMA_SYNC_16x16x16_f16f16f32_INTRIN, + # unit iter is not mapped + lambda n, h, w, c, rh, rw, rc: (n, h * 16 + w, c, rh * 192 + rw * 64 + rc), + ) + + +@pytest.mark.parametrize("b,m,n,k", [(1, 512, 512, 512), (16, 32, 32, 32)]) +def test_get_auto_tensorize_mapping_info_batch_matmul(b, m, n, k): + matmul = create_prim_func(te_workload.batch_matmul_nkkm_f16(b, m, n, k)) + check_index_map( + matmul, "Z", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, lambda b, m, n, k: (b, m, n, k) + ) -def test_get_tensorize_layout_info_gmm(): - s = Schedule(batch_matmul) - block = s.get_block('batch_matmul') - print(batch_matmul.script()) - info = get_tensorize_layout_info(s, block, get_intrin_desc(cuda_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN)) - print(info.mapping) if __name__ == "__main__": - test_get_tensorize_layout_info() \ No newline at end of file + tvm.testing.main() From 427d21700c6286ed32444a4f2247aba5face207a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 16 Jun 2022 13:25:13 -0700 Subject: [PATCH 3/6] AutoTensorizeExtractor -> AutoTensorizeComparator --- src/tir/schedule/analysis.h | 2 +- src/tir/schedule/analysis/analysis.cc | 8 ++++---- src/tir/schedule/ir_comparator.cc | 20 ++++++++++---------- 
src/tir/schedule/ir_comparator.h | 15 +++++++++++---- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 329a9a57971b..277438d85d6f 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -713,7 +713,7 @@ class AutoTensorizeMappingInfoNode : public Object { /*! \brief Possible mappings to apply to block iters */ Array mappings; - /* Additional information from AutoTensorizeExtractor */ + /* Additional information from AutoTensorizeComparator */ /*! \brief Mapping from LHS buffer to RHS buffer */ Map lhs_buffer_map; diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 5465eca41ece..821d81bb844f 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2270,7 +2270,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping") /*! \brief IndexMap proposer for layout transformation in auto tensorization. */ class AutoTensorizeMappingProposer { public: - static Array ProposeMappings(const AutoTensorizeExtractor* extractor, + static Array ProposeMappings(const AutoTensorizeComparator* extractor, arith::Analyzer* analyzer) { AutoTensorizeMappingProposer proposer(extractor, analyzer); proposer.CollectFeasibleSet(); @@ -2278,7 +2278,7 @@ class AutoTensorizeMappingProposer { } private: - explicit AutoTensorizeMappingProposer(const AutoTensorizeExtractor* extractor, + explicit AutoTensorizeMappingProposer(const AutoTensorizeComparator* extractor, arith::Analyzer* analyzer) : extractor_(extractor), analyzer_(analyzer) {} @@ -2425,7 +2425,7 @@ class AutoTensorizeMappingProposer { private: // The extractor that has extracted information for auto tensorization from the workload and the // tensor intrin. - const AutoTensorizeExtractor* extractor_; + const AutoTensorizeComparator* extractor_; // The arithmetic analyzer. arith::Analyzer* analyzer_; /*! 
\brief Potential mappings on RHS for each variable on LHS */ @@ -2443,7 +2443,7 @@ Optional GetAutoTensorizeMappingInfo(const tir::Schedu // Ignore the scope of buffers when comparing, since we can do cache_read/write const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); - AutoTensorizeExtractor extractor(self->mod); + AutoTensorizeComparator extractor(self->mod); if (!extractor.VisitStmt(block->block, desc_info.desc_block->block)) { return NullOpt; } diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index a1f1aedf85bf..2b68a7f9bc41 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -357,16 +357,16 @@ void TensorizeComparator::EmitError(const std::string& error_message) { /******** AutoTensorize Extractor ********/ -bool AutoTensorizeExtractor::VisitExprDefault_(const Object* op, const PrimExpr& other) { +bool AutoTensorizeComparator::VisitExprDefault_(const Object* op, const PrimExpr& other) { return false; } -bool AutoTensorizeExtractor::VisitStmtDefault_(const Object* op, const Stmt& other) { +bool AutoTensorizeComparator::VisitStmtDefault_(const Object* op, const Stmt& other) { return false; } template -bool AutoTensorizeExtractor::CompareArray(const Array& lhs, const Array& rhs, F cmp) { +bool AutoTensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F cmp) { if (lhs.same_as(rhs)) return true; if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); ++i) { @@ -375,20 +375,20 @@ bool AutoTensorizeExtractor::CompareArray(const Array& lhs, const Array& r return true; } -bool AutoTensorizeExtractor::VisitStmt_(const BlockNode* op, const Stmt& other) { +bool AutoTensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { const auto* rhs = other.as(); // Check block equality. // All iter vars and buffer regions including the order should match. 
// When checking iter vars, DefEqual is used to remap variables. if (!is_scope_block) { - if (!CompareArray(op->iter_vars, rhs->iter_vars, &AutoTensorizeExtractor::CompareIterVar)) { + if (!CompareArray(op->iter_vars, rhs->iter_vars, &AutoTensorizeComparator::CompareIterVar)) { return false; } if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { return false; } if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, - &AutoTensorizeExtractor::CompareBuffer)) { + &AutoTensorizeComparator::CompareBuffer)) { return false; } for (const IterVar& block_iter : op->iter_vars) { @@ -418,7 +418,7 @@ bool AutoTensorizeExtractor::VisitStmt_(const BlockNode* op, const Stmt& other) return VisitStmt(op->body, rhs->body); } -bool AutoTensorizeExtractor::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { +bool AutoTensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { if (lhs.same_as(rhs)) return true; auto it = rhs_buffer_map_.find(rhs); bool equal; @@ -435,18 +435,18 @@ bool AutoTensorizeExtractor::CompareBuffer(const Buffer& lhs, const Buffer& rhs) return equal; } -bool AutoTensorizeExtractor::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { +bool AutoTensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { const auto* rhs = other.as(); return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); } -bool AutoTensorizeExtractor::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { +bool AutoTensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { const auto* rhs = other.as(); return CompareBufferAccess(op, rhs); } template -bool AutoTensorizeExtractor::CompareBufferAccess(const T* lhs, const T* rhs) { +bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; auto it_lhs = lhs_buffer_indices_map_.find(lhs->buffer); if (it_lhs == lhs_buffer_indices_map_.end()) { diff --git 
a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index df0536a13a15..3c7f4511f733 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -110,12 +110,17 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { std::unordered_map equal_map_; }; -/*! \brief IR comparator for auto tensorization. Extract correspondence between the IR of the - * workload and the tensor intrin. +/*! \brief IR comparator for auto tensorization. This comparator is used to extract correspondence + * between the IR of the workload (LHS) and the tensor intrin (RHS). Unlike `TensorizeComparator`, + * this comparator has relaxed requirements during comparison. It ignores the loop structure + * (number of loops and their extents) and buffer indices. It only requires the LHS and the + * RHS to have the same arithmetic operations and the same dtype. With such relaxed requirements, + * workloads that can only match the tensor intrin after certain transformations (e.g. im2col for + * conv2d) are allowed for auto tensorization. */ -class AutoTensorizeExtractor : public TensorizeComparator { +class AutoTensorizeComparator : public TensorizeComparator { public: - explicit AutoTensorizeExtractor(const IRModule& lhs_mod) + explicit AutoTensorizeComparator(const IRModule& lhs_mod) : TensorizeComparator(lhs_mod, /* assert_mode=*/false) {} private: @@ -134,6 +139,8 @@ class AutoTensorizeExtractor : public TensorizeComparator { bool CompareBufferAccess(const T* lhs, const T* rhs); public: + // Additional information extracted from LHS (the workload) and RHS (the tensor intrin). + /*! \brief Block iters in the LHS stmt. */ std::vector lhs_iters_; /*! \brief Block iters in the RHS stmt. 
*/ From 58c796e97d0f6d61715c17a01a84a77df08baf0b Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 16 Jun 2022 13:59:30 -0700 Subject: [PATCH 4/6] Remove duplicate ComparaArray --- src/tir/schedule/ir_comparator.cc | 16 +++------------- src/tir/schedule/ir_comparator.h | 6 ++---- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 2b68a7f9bc41..d8ac40ef0586 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -333,12 +333,12 @@ bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { return true; } -template -bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F cmp) { +template +bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F Self::*cmp) { if (lhs.same_as(rhs)) return true; if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); ++i) { - if (!(this->*cmp)(lhs[i], rhs[i])) return false; + if (!(static_cast(this)->*cmp)(lhs[i], rhs[i])) return false; } return true; } @@ -365,16 +365,6 @@ bool AutoTensorizeComparator::VisitStmtDefault_(const Object* op, const Stmt& ot return false; } -template -bool AutoTensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F cmp) { - if (lhs.same_as(rhs)) return true; - if (lhs.size() != rhs.size()) return false; - for (size_t i = 0; i < lhs.size(); ++i) { - if (!(this->*cmp)(lhs[i], rhs[i])) return false; - } - return true; -} - bool AutoTensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { const auto* rhs = other.as(); // Check block equality. 
diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index 3c7f4511f733..a6d0455a2ee7 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -90,8 +90,8 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { bool CompareAnnotationMap(const Map& lhs, const Map& rhs); template bool CompareBufferAccess(const T* lhs, const T* rhs); - template - bool CompareArray(const Array& lhs, const Array& rhs, F cmp); + template + bool CompareArray(const Array& lhs, const Array& rhs, F Self::*cmp); bool CompareRange(const Range& lhs, const Range& rhs); bool CompareIterVar(const IterVar& lhs, const IterVar& rhs); void EmitError(const std::string& error_message); @@ -132,8 +132,6 @@ class AutoTensorizeComparator : public TensorizeComparator { bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; - template - bool CompareArray(const Array& lhs, const Array& rhs, F cmp); bool CompareBuffer(const Buffer& lhs, const Buffer& rhs) override; template bool CompareBufferAccess(const T* lhs, const T* rhs); From b32dd62a781119e99f181fb7f91601a76ef71530 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 16 Jun 2022 15:16:18 -0700 Subject: [PATCH 5/6] lint --- src/tir/schedule/ir_comparator.h | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index a6d0455a2ee7..394d82867393 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -110,13 +110,14 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { std::unordered_map equal_map_; }; -/*! \brief IR comparator for auto tensorization. This comparator is used to extract correspondence - * between the IR of the workload (LHS) and the tensor intrin (RHS). Unlike `TensorizeComparator`, - * this comparator has relaxed requirements during comparison. 
It ignores the loop structure - * (number of loops and their extents) and buffer indices. It only requires the LHS and the - * RHS to have the same arithmetic operations and the same dtype. With such relaxed requirements, - * workloads that can only match the tensor intrin after certain transformations (e.g. im2col for - * conv2d) are allowed for auto tensorization. +/*! + * \brief IR comparator for auto tensorization. + * This comparator is used to extract correspondence between the IR of the workload (LHS) and the + * tensor intrin (RHS). Unlike `TensorizeComparator`, this comparator has relaxed requirements + * during comparison. It ignores the loop structure (number of loops and their extents) and buffer + * indices. It only requires the LHS and the RHS to have the same arithmetic operations and the same + * dtype. With such relaxed requirements, workloads that can only match the tensor intrin after + * certain transformations (e.g. im2col for conv2d) are allowed for auto tensorization. 
*/ class AutoTensorizeComparator : public TensorizeComparator { public: From b1ab4fea4d8ac4492cdaf7b545e66bc793615a65 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 17 Jun 2022 14:32:32 -0700 Subject: [PATCH 6/6] check iter type --- src/tir/schedule/analysis.h | 1 + src/tir/schedule/analysis/analysis.cc | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 277438d85d6f..b30cef829f1e 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -726,6 +726,7 @@ class AutoTensorizeMappingInfoNode : public Object { void VisitAttrs(AttrVisitor* v) { v->Visit("mappings", &mappings); + v->Visit("lhs_buffer_map", &lhs_buffer_map); v->Visit("rhs_buffer_indices", &rhs_buffer_indices); v->Visit("lhs_iters", &lhs_iters); v->Visit("rhs_iters", &rhs_iters); diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 821d81bb844f..3ee1ed28b857 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2288,7 +2288,8 @@ class AutoTensorizeMappingProposer { // Collect the set of potential iter var mapping between the workload and the tensor intrin. // We analyze the appearance of each variable in the buffer indices of each buffer on LHS and // RHS. The appearance of a variable in the buffer indices is encoded as bit-masks (BufferMask). - // Variables on the LHS and the RHS with the same bit-mask are potential mappings. + // Variables on the LHS and the RHS with the same bit-mask and the same iter type are potential + // mappings. // // For example, consider the conv2d case. We will try to match the workload // conv2d[n, h, w, c] = sum_{rh, rw, rc} X[n, h + rh, w + rw, c + rc] * W[rh, rw, rc, c] @@ -2302,7 +2303,7 @@ class AutoTensorizeMappingProposer { // both buffer conv2d and W, and not in other buffers. Therefore, {n, h, w} <=> m is a potential // mapping. 
- // Note: the mapping is not unique when multiple variables in RHS has the same bit-mask. + // Note: the mapping is not unique when multiple variables on RHS have the same bit-mask. // This is currently not supported. using BufferMask = std::vector; @@ -2358,16 +2359,25 @@ class AutoTensorizeMappingProposer { } } - // Step 3: Find variables on LHS and RHS with the same buffer mask + // Step 3: Find variables on LHS and RHS with the same buffer mask. Ensure LHS and RHS vars + // have the same iter type. std::unordered_map mask_to_rhs_vars; for (const auto& kv : rhs_buffer_masks) { const VarNode* rhs_var = kv.first; const BufferMask& mask = kv.second; mask_to_rhs_vars[mask].insert(GetRef(rhs_var)); } - + std::unordered_map rhs_var_iter_type; + for (const auto& iter : extractor_->rhs_iters_) { + rhs_var_iter_type.emplace(iter->var.get(), iter->iter_type); + } for (const auto& iter : extractor_->lhs_iters_) { - lhs_feasible_vars_[iter->var] = mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]]; + auto& potential_mappings = lhs_feasible_vars_[iter->var]; + VarSet rhs_candidates = mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]]; + std::copy_if( + rhs_candidates.begin(), rhs_candidates.end(), + std::inserter(potential_mappings, potential_mappings.begin()), + [&](const Var& var) { return rhs_var_iter_type.at(var.get()) == iter->iter_type; }); } }