From b2b8e3cccc73a62e71ef6e34dcd2a70602e47120 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 8 Apr 2023 22:18:14 -0400 Subject: [PATCH 1/3] [ARITH] Enhance IterMapSimplify for symbolic This PR refactors and enhances DetectIterMap and IterMapSimplify to enable symbolic shape simplification. Specifically, we add a routine to combine multiple IterSplitExpr into one if they come from the same source. It is helpful to distinguish iterator from normal constants in the simplification process. IterMapSimplify takes advantage of these information. This improvements is helpful to simplify the indices in flattened buffer when there is symbolic shape involved and normal simplifier. Also updated FlattenBuffer to take benefit of the enhanced simplifier. Test cases are added. --- include/tvm/arith/iter_affine_map.h | 3 +- python/tvm/arith/__init__.py | 1 + python/tvm/arith/iter_affine_map.py | 43 +++ src/arith/canonical_simplify.cc | 10 +- src/arith/ir_mutator_with_analyzer.cc | 14 +- src/arith/ir_mutator_with_analyzer.h | 29 +- src/arith/iter_affine_map.cc | 335 ++++++++++++++---- src/arith/pattern_match.h | 18 - src/arith/product_normal_form.h | 89 +++++ .../schedule/primitive/loop_transformation.cc | 3 + src/tir/transforms/flatten_buffer.cc | 40 ++- .../unittest/test_arith_iter_affine_map.py | 97 +++++ .../test_tir_transform_flatten_buffer.py | 45 +++ 13 files changed, 634 insertions(+), 93 deletions(-) create mode 100644 src/arith/product_normal_form.h diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 0d8bd574ae6e..d89d3126a6f9 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -349,12 +349,13 @@ IterMapResult DetectIterMap(const Array& indices, const Map IterMapSimplify(const Array& indices, const Map& input_iters, const PrimExpr& input_pred, IterMapLevel check_level, - bool simplify_trivial_iterators = true); + arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); /*! * \brief Apply the inverse of the affine transformation to the outputs. diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 401836aa1968..e2f6127e292e 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -30,6 +30,7 @@ from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr from .iter_affine_map import ( detect_iter_map, + iter_map_simplify, normalize_iter_map_to_expr, subspace_divide, inverse_affine_iter_map, diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 54dbcef32590..34487d00f02d 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -156,6 +156,49 @@ def detect_iter_map( ) +def iter_map_simplify( + indices, + input_iters, + predicate=True, + check_level=IterMapLevel.Surjective, + simplify_trivial_iterators=True, +): + """Simplify the indices using iter map detection. + + Parameters + ---------- + indices : List[PrimExpr] + The input indices + + input_iters : Map[Var, Range] + The domain of each input iterators. + + predicate : PrimExpr + The predicate constraints on the input iterators + + check_level : Union[str, IterMapLevel] + Checking level of iteration mapping + + simplify_trivial_iterators: bool + If true, iterators with extent of 1 will be replaced with a + constant value. + + Returns + ------- + results : IterMapResult + The iter map matching result. + The result's .indices is empty array if no match can be found. 
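+
+    Examples
+    --------
+    A minimal usage sketch, mirroring the unit tests added alongside this
+    API (the variable names and extents here are only illustrative):
+
+    .. code-block:: python
+
+        x = tvm.tir.Var("x", "int64")
+        n = tvm.tir.SizeVar("n", "int64")
+        # the fused flattening pattern simplifies back to the iterator itself
+        [res] = tvm.arith.iter_map_simplify(
+            [(x // n) * n + x % n],
+            {x: tvm.ir.Range(tvm.tir.const(0, "int64"), n * 32)},
+        )
+        tvm.ir.assert_structural_equal(res, x)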
+ + """ + if isinstance(check_level, str): + check_level = IterMapLevel.from_str(check_level) + elif check_level is None: + check_level = IterMapLevel.NoCheck + return _ffi_api.IterMapSimplify( + indices, input_iters, predicate, check_level, simplify_trivial_iterators + ) + + def normalize_iter_map_to_expr(expr): """Given an IterMapExpr, transform it to normal PrimExpr diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 14c91934d3b2..2ac126010964 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -27,6 +27,7 @@ #include "const_fold.h" #include "pattern_match.h" +#include "product_normal_form.h" #include "rewrite_simplify.h" namespace tvm { @@ -808,12 +809,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { } // normal path. + // this only happens when b is symbolic a = Normalize(a); b = Normalize(b); - if (op->a.same_as(a) && op->b.same_as(b)) { + + PrimExpr ret = MulAndNormalize(a, b); + const MulNode* mul = ret.as(); + + if (mul && mul->a.same_as(op->a) && mul->b.same_as(op->b)) { return GetRef(op); } else { - return Mul(a, b); + return ret; } } diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 199f06191e4e..c201a245e190 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -31,13 +31,17 @@ namespace arith { using namespace tir; Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) { - analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + // record the loop variable as iterators + Range dom = Range::FromMinExtent(op->min, op->extent); + analyzer_->Bind(op->loop_var, dom); + iter_vars_.Set(op->loop_var, dom); return StmtExprMutator::VisitStmt_(op); } Stmt IRMutatorWithAnalyzer::VisitStmt_(const BlockNode* op) { for (const auto& iter_var : op->iter_vars) { analyzer_->Bind(iter_var->var, iter_var->dom); + iter_vars_.Set(iter_var->var, iter_var->dom); } return StmtExprMutator::VisitStmt_(op); } @@ -75,7 +79,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { Optional else_case; { With ctx(analyzer_, real_condition); - then_case = this->VisitStmt(op->then_case); + WithRecordIterPredicate(real_condition, [&] { then_case = this->VisitStmt(op->then_case); }); } if (op->else_case) { With ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition))); @@ -102,7 +106,9 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); ICHECK_NE(iv->thread_tag.length(), 0U); - analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value)); + Range dom = Range::FromMinExtent(make_zero(op->value.dtype()), op->value); + analyzer_->Bind(iv->var, dom); + iter_vars_.Set(iv->var, dom); Stmt stmt = StmtExprMutator::VisitStmt_(op); return stmt; } else { @@ -135,7 +141,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { PrimExpr true_value, false_value; { With constraint(analyzer_, cond); - true_value = this->VisitExpr(op->args[1]); + WithRecordIterPredicate(cond, [&] { true_value = this->VisitExpr(op->args[1]); }); } { With constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond))); diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index 3bd3a98a8445..ed62c91df913 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -25,6 +25,7 @@ #define 
TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ #include +#include #include #include @@ -63,8 +64,34 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { protected: /*! \brief internal analyzer field. */ Analyzer* analyzer_; + // the following two fields are useful in case we want + // note however that iter map analysis are usually more + // expensive and we only encourage doing them during + // necessary cases like layout remapping + /*! \brief Recorded loop iterators */ + Map iter_vars_; + /*! \brief iterator predicates */ + Array iter_predicates_; + /*! + * \brief Run callback while trying to record iter predicate + * \param conditon Condition to be checked. + * \param callback The callback to be called. + */ + template + void WithRecordIterPredicate(PrimExpr condition, FLambda callback) { + auto f_use_itervar = [this](const tir::VarNode* v) { + return iter_vars_.count(GetRef(v)); + }; + // simple heuristics for detecting predicate + if (tir::UsesVar(condition, f_use_itervar)) { + iter_predicates_.push_back(condition); + callback(); + iter_predicates_.pop_back(); + } else { + callback(); + } + } }; - } // namespace arith } // namespace tvm #endif // TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 05af5b40702d..7522a819661f 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -33,6 +33,7 @@ #include "../support/utils.h" #include "const_fold.h" #include "pattern_match.h" +#include "product_normal_form.h" namespace tvm { namespace arith { @@ -678,7 +679,20 @@ class IterMapRewriter : public ExprMutator { iter_min = max(predicate_induced_min.value(), iter_min); } if (predicate_induced_max.defined()) { - iter_max = min(predicate_induced_max.value(), iter_max); + // NOTE: important to do explicit prove here + // because we have a domain knowledge that most predicates + // tries to constraint the expression and we favor predicate_induced_max + // when available. + // + // This path can help enable predicate simplfication for + // symbolic cases like: + // + // z = x * 32 + y < n * 16 where x in [0, (n+1)//2), y in [0, 32) + if (analyzer_->CanProve(predicate_induced_max.value() <= iter_max)) { + iter_max = predicate_induced_max.value(); + } else { + iter_max = min(predicate_induced_max.value(), iter_max); + } } if (!is_zero(iter_min)) { // structured form's offset should be updated @@ -733,6 +747,217 @@ class IterMapRewriter : public ExprMutator { } } + /** + * \brief Helper method to find base iterator which is the + * iterator with the smallest scale. + * + * \param expr The expression to search. + * \param skip_flag Whether to skip the position + * \param match_source Whether to only match the same source. + * \param rbegin The last index to start reverse searching, -1 means everything. + * \return Whether we can find one. + */ + int FindBaseIter(const IterSumExpr& expr, const std::vector& skip_flag, + Optional match_source, int rbegin = -1) { + if (rbegin == -1) { + rbegin = static_cast(expr->args.size()) - 1; + } + // First, find the scale with minimum size of constant scale. 
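+    // e.g. for x*6 + y*2 + z (constant scales 6, 2 and 1) the base
+    // iterator is z, whose constant scale 1 is the smallest.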
+ // use reverse search as usually smallest is ordered on the right + int base_index = -1; + int64_t min_const_scale = 0; + for (int i = rbegin; i >= 0; --i) { + if (skip_flag[i]) continue; + if (match_source.defined() && !match_source.same_as(expr->args[i]->source)) continue; + if (const auto* op = expr->args[i]->scale.as()) { + if (base_index == -1 || op->value < min_const_scale) { + min_const_scale = op->value; + base_index = static_cast(i); + } + } + } + // cannot find constant scale, try to find scale that comes with + // smallest product size, which usually is smallest in symbolic land + // x < x * y + int min_reduce_size = 0; + for (int i = rbegin; i >= 0; --i) { + if (skip_flag[i]) continue; + if (match_source.defined() && !match_source.same_as(expr->args[i]->source)) continue; + int reduce_size = 0; + auto fcollect = [&](const PrimExpr&) { ++reduce_size; }; + UnpackReduction(expr->args[i]->scale, fcollect); + if (base_index == -1 || reduce_size < min_reduce_size) { + min_reduce_size = reduce_size; + base_index = static_cast(i); + } + } + return base_index; + } + + /*! + * \brief Helper method to find iterator with exact the expected scale. + * \param expr The expression. + * \param skip_flag skip_flag the position already visited to skip. + * \param match_source Must match the same source. + * \param expected_scale The expected_scale. + * \param rbegin The last index to start reverse searching, -1 means everything. + * \return -1 if not no match found, otherwise return the index. + */ + int FindIterWithExactScale(const IterSumExpr& expr, const std::vector& skip_flag, + const PrimExpr& expected_scale, Optional match_source, + int rbegin = -1) { + if (rbegin == -1) { + rbegin = static_cast(expr->args.size()) - 1; + } + // use reverse search, as smallest scale usually are near the end. + for (int j = rbegin; j >= 0; --j) { + if (skip_flag[j]) continue; + if (match_source.defined() && !match_source.same_as(expr->args[j]->source)) continue; + const PrimExpr& cur_scale = expr->args[j]->scale; + // for bijective mapping, the matched scale must equal to expected scale + if (analyzer_->CanProveEqual(cur_scale, expected_scale)) { + return j; + } + } + return -1; + } + + /*! + * \brief Helper method to find iterator whose scale is smaller + * than but closest to the expected scale. + * \param expr The expression. + * \param skip_flag skip_flag the position already visited to skip. + * \param expected_scale The expected_scale. + * \return -1 if not no match found, otherwise return the index. + */ + int FindIterSmallerClosestToScale(const IterSumExpr& expr, const std::vector& skip_flag, + const PrimExpr& expected_scale, PrimExpr* out_matched_scale) { + // use reverse search, as smallest scale usually are near the end. + int matched_pos = -1; + PrimExpr matched_scale; + for (int j = static_cast(expr->args.size()) - 1; j >= 0; --j) { + if (skip_flag[j]) continue; + const PrimExpr& cur_scale = expr->args[j]->scale; + // find the closest scale which is less or equal to expected scale + if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) && + analyzer_->CanProveGreaterEqual(cur_scale, 0)) { + if (matched_pos == -1 || analyzer_->CanProveLess(matched_scale - cur_scale, 0)) { + matched_pos = j; + matched_scale = cur_scale; + } + } + } + *out_matched_scale = matched_scale; + return matched_pos; + } + /*! + * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn + base + * = (x1*s1 + x2*s2 + ... + xn)*cn + base + * + * Try to combine consecutives IterSplit. 
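+   * For example, ((x // 8) % 4) * 8 + x % 8 consists of two splits of the
+   * same source x and can be combined into the single split x % 32
+   * (an instance of the rewrite rule documented in the body below).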
+ * This is helpful to combine iterators from consecutive splits. + * + * \param expr The input sum. + * \param check_level The check level if iter mapping. + * \return The sum with the fused IterMark and extra offset if succeed. + */ + Optional TryCombineSplitFromSameSource(IterSumExpr expr) { + if (expr->args.size() <= 1) return NullOpt; + std::unordered_map hit_count; + // most iter map are small n < 5 + // so we can afford N^2 complexity + bool has_overlap = false; + for (size_t i = 0; i < expr->args.size(); ++i) { + auto it = hit_count.find(expr->args[i]->source); + if (it != hit_count.end()) { + ++it->second; + has_overlap = true; + } else { + hit_count[expr->args[i]->source] = 1; + } + } + if (!has_overlap) return NullOpt; + + std::vector visited(expr->args.size(), false); + Array reverse_flattened_iters; + + // Start eliminating the iterators + for (int rend = static_cast(expr->args.size()) - 1; rend >= 0;) { + if (visited[rend]) { + --rend; + continue; + } + if (hit_count.at(expr->args[rend]->source) == 1) { + reverse_flattened_iters.push_back(expr->args[rend]); + visited[rend] = true; + --rend; + continue; + } + // NOTE: split have the following pattern + // + // result = (source // lower_factor) % extent * scale + // = (source % (extent * lower_factor)) // lower_factor * scale (rule A) + // + // Try to simplify with the following rule: + // + // ((x // (c * s)) % m) * s + ((x // c) % s) + // => (x // c) % (m * s) + // + // Assume we have the following split relations: + // - lhs = ((x // (c * s)) % m) * (s * t) + // - rhs = ((x // c) % s) * t + // - result = combine(lhs, rhs) = (x // c) % (m * s) * t + // + // Key things to match: + // - lhs->lower_factor == rhs->lower_factor * rhs->extent + // - lhs->scale == rhs->extent * rhs->scale + // + // The final result contains the following relation + // - result->lower_factor = rhs->lower_factor + // - result->scale = rhs->scale + // - result->extent = lhs->extent * rhs->extent + // Find base index, must have a candidate to make progress + int matched_index = FindBaseIter(expr, visited, expr->args[rend]->source, rend); + ICHECK_NE(matched_index, -1); + visited[matched_index] = true; + IterSplitExpr rhs_iter = expr->args[matched_index]; + + while (true) { + // NOTE: mul order [lower_factor, extent, scale] + PrimExpr lhs_scale = MulAndNormalize(rhs_iter->extent, rhs_iter->scale); + matched_index = FindIterWithExactScale(expr, visited, lhs_scale, rhs_iter->source, rend); + if (matched_index == -1) break; + IterSplitExpr lhs_iter = expr->args[matched_index]; + ICHECK(rhs_iter->source.same_as(lhs_iter->source)); + PrimExpr lhs_lower_factor = MulAndNormalize(rhs_iter->lower_factor, rhs_iter->extent); + if (!analyzer_->CanProveEqual(lhs_iter->lower_factor, lhs_lower_factor)) break; + // all patterns match + visited[matched_index] = true; + // Update rhs iter to result, only update of extent is necessary + rhs_iter.CopyOnWrite()->extent = MulAndNormalize(lhs_iter->extent, rhs_iter->extent); + } + // push back the combined iter in rhs_iter + reverse_flattened_iters.push_back(rhs_iter); + } + + // if we simplify to sum([iter_sum] * scale), fold to previous iter sum + if (reverse_flattened_iters.size() == 1 && is_zero(expr->base)) { + IterSplitExpr iter = reverse_flattened_iters[0]; + if (is_one(iter->lower_factor) && + analyzer_->CanProveEqual(iter->source->extent, iter->extent)) { + if (auto* ptr = iter->source->source.as()) { + IterSumExpr ref = GetRef(ptr); + MulToLhs(ref.CopyOnWrite(), iter->scale); + return ref; + } + } + } + + IterSumExpr 
simplified_sum = expr; + simplified_sum.CopyOnWrite()->args = reverse_flattened_iters; + return simplified_sum; + } + /*! * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn + base * = (x1*s1 + x2*s2 + ... + xn)*cn + base @@ -745,64 +970,45 @@ class IterMapRewriter : public ExprMutator { * \return The sum with the fused IterMark and extra offset if succeed. */ Optional TryFuseIters(IterSumExpr expr, IterMapLevel check_level) { + if (auto opt = TryCombineSplitFromSameSource(expr)) { + expr = opt.value(); + } // select the iterators in order std::vector visited(expr->args.size(), false); + int base_index = FindBaseIter(expr, visited, NullOpt); + if (base_index == -1) return NullOpt; + PrimExpr base_scale = expr->args[base_index]->scale; + std::vector flattened_iters, grouped_iters; - // canonicalize the expression into two different forms: flattened form and structured form - // step0. check if find the base scale first - Optional base_scale = NullOpt; - size_t base_index = 0; - for (size_t i = 0; i < expr->args.size(); ++i) { - if (const auto* op = expr->args[i]->scale.as()) { - if (!base_scale || op->value < base_scale.value()->value) { - base_scale = GetRef(op); - base_index = i; - } - } - } - if (!base_scale) { - return NullOpt; - } + // check if it can be remapped into a fused pattern. PrimExpr expected_extra_base = 0; PrimExpr tail_extent = 0; - PrimExpr expected_scale = base_scale.value(); - for (size_t i = 0; i < expr->args.size();) { - // find position such that expr->args[j] match expected scale - int j = i == 0 ? base_index : expr->args.size() - 1; + PrimExpr expected_scale = base_scale; - size_t matched_pos = expr->args.size(); + for (size_t i = 0; i < expr->args.size();) { PrimExpr matched_scale{nullptr}; bool is_exact_match{false}; - - for (; j >= 0; --j) { - if (visited[j]) { - continue; - } - const PrimExpr& cur_scale = expr->args[j]->scale; - - // for bijective mapping, the matched scale must equal to expected scale - if (analyzer_->CanProveEqual(cur_scale, expected_scale)) { - matched_pos = j; - matched_scale = cur_scale; - is_exact_match = true; - break; - } - if (check_level != IterMapLevel::Bijective && base_scale.value()->value == 1) { - // find the closest scale which is less or equal to expected scale - if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) && - analyzer_->CanProveGreaterEqual(cur_scale, 0)) { - if (matched_pos == expr->args.size() || - analyzer_->CanProveLess(matched_scale - cur_scale, 0)) { - matched_pos = j; - matched_scale = cur_scale; - } - } + // find position such that expr->args[j] match expected scale + // if it is first step, we can simply start with base index + int matched_pos = + i == 0 ? base_index : FindIterWithExactScale(expr, visited, expected_scale, NullOpt); + if (matched_pos != -1) { + matched_scale = expected_scale; + is_exact_match = true; + } + if (matched_pos == -1) { + // if exact scale is not possible, try to find an iter with scale + // that is smaller but closest to the scale. 
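+          // e.g. if the expected scale is 8 and the remaining candidates
+          // have scales 6 and 4, the iterator with scale 6 is chosen and
+          // the gap (8 - 6) is accumulated into tail_extent below.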
+ if (check_level != IterMapLevel::Bijective && is_const_int(base_scale, 1)) { + matched_pos = + FindIterSmallerClosestToScale(expr, visited, expected_scale, &matched_scale); } } - if (matched_pos == expr->args.size()) { + if (matched_pos == -1) { return NullOpt; } + ICHECK(matched_scale.defined()); // look for the longest constrained iter started from expr->args[j] // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate: j*2 + k < 9 @@ -841,25 +1047,27 @@ class IterMapRewriter : public ExprMutator { auto iter = sum_fuse_map_.find(constraint_to_match.value()); ICHECK(iter != sum_fuse_map_.end()); const IterMarkWithOffset& iter_matched = iter->second; - grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, base_scale.value())); + grouped_iters.emplace_back(iter_matched.mark, floordiv(matched_scale, base_scale)); expected_extra_base += iter_matched.offset * matched_scale; if (!is_exact_match) { tail_extent += expected_scale - matched_scale; } - expected_scale = matched_scale * iter_matched.mark->extent; + // NOTE: order [lower_factor, extent, scale] + expected_scale = MulAndNormalize(iter_matched.mark->extent, matched_scale); // move forward i += constraint_to_match.value()->args.size(); } else { // constraint_to_match not found, skip this iterator visited[matched_pos] = true; IterSplitExpr arg = expr->args[matched_pos]; - arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, base_scale.value())); + arg.CopyOnWrite()->scale = analyzer_->Simplify(floordiv(arg->scale, base_scale)); flattened_iters.push_back(arg); grouped_iters.push_back(arg); if (!is_exact_match) { tail_extent += expected_scale - matched_scale; } - expected_scale = matched_scale * expr->args[matched_pos]->extent; + // NOTE: order [lower_factor, extent, scale] + expected_scale = MulAndNormalize(expr->args[matched_pos]->extent, matched_scale); ++i; } } @@ -875,20 +1083,18 @@ class IterMapRewriter : public ExprMutator { auto it = sum_fuse_map_.find(flattened_form); if (it != sum_fuse_map_.end()) { // old iter - if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale.value())) { + if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale)) { // the extra offset is not consistent with old return NullOpt; } - return IterSumExpr({IterSplitExpr(it->second.mark, base_scale.value())}, + return IterSumExpr({IterSplitExpr(it->second.mark, base_scale)}, expr->base + expected_extra_base); } else { // new iter, form a new mark - IterMark mark = - IterMark(structured_form, div(expected_scale, base_scale.value()) + tail_extent); + IterMark mark = IterMark(structured_form, div(expected_scale, base_scale) + tail_extent); sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0); flattened_map_[structured_form] = flattened_form; - return IterSumExpr({IterSplitExpr(mark, base_scale.value())}, - expr->base + expected_extra_base); + return IterSumExpr({IterSplitExpr(mark, base_scale)}, expr->base + expected_extra_base); } } @@ -1167,6 +1373,7 @@ IterMapResult DetectIterMap(const Array& indices, const Mappadding_predicate = rewriter.padding_predicate(); + // // Step1: IterIndependenceChecker checks if the iterator are independent. 
if (!rewriter.CheckMapping(rewrite_indices, check_level)) { @@ -1760,10 +1967,9 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIter Array IterMapSimplify(const Array& indices, const Map& input_iters, const PrimExpr& input_pred, IterMapLevel check_level, - bool simplify_trivial_iterators) { + arith::Analyzer* ana, bool simplify_trivial_iterators) { if (!IterRangeSanityCheck(input_iters)) return indices; - Analyzer analyzer; - auto res = DetectIterMap(indices, input_iters, input_pred, check_level, &analyzer, + auto res = DetectIterMap(indices, input_iters, input_pred, check_level, ana, /*simplify_trivial_iterators=*/simplify_trivial_iterators); Array rewrite = res->indices; @@ -1772,11 +1978,20 @@ Array IterMapSimplify(const Array& indices, const Map simplified; simplified.reserve(rewrite.size()); - IterMapToExprNormalizer converter(&analyzer); + IterMapToExprNormalizer converter(ana); for (const auto& expr : rewrite) simplified.push_back(converter.Convert(expr)); return simplified; } +TVM_REGISTER_GLOBAL("arith.IterMapSimplify") + .set_body_typed([](const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, int check_level, + bool simplify_trivial_iterators) { + arith::Analyzer ana; + return IterMapSimplify(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, + simplify_trivial_iterators); + }); + /*! * \brief Divider to divide the bindings into two sets of bindings(outer and inner) * such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X. diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 0bb172e56053..d057a840e8b7 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -914,24 +914,6 @@ inline std::enable_if_t<(std::is_base_of_v, TPattern> && ... & matches_one_of(const TPattern&... patterns) { return PMatchesOneOf(patterns...); } - -/*! - * \brief Unpack reduction by calling each leaf via fleaf. - * - * \param value The expression value. - * \tparam TNode the reduction node to match. - * \tparam FLeaf The callback function at leaf. - */ -template -inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) { - if (const TNode* node = value.as()) { - UnpackReduction(node->a, fleaf); - UnpackReduction(node->b, fleaf); - } else { - fleaf(value); - } -} - } // namespace arith } // namespace tvm #endif // TVM_ARITH_PATTERN_MATCH_H_ diff --git a/src/arith/product_normal_form.h b/src/arith/product_normal_form.h new file mode 100644 index 000000000000..178b372b2801 --- /dev/null +++ b/src/arith/product_normal_form.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file product_normal_form.h + * \brief Centralized location related to simplifying prod of results. 
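+ *
+ * The normal form keeps symbolic factors left-nested as (x * y) * z, in the
+ * order they are encountered, and folds any constant factor into a single
+ * trailing multiplier, so that later simplification steps see a consistent
+ * multiplication order.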
+ */ +#ifndef TVM_ARITH_PRODUCT_NORMAL_FORM_H_ +#define TVM_ARITH_PRODUCT_NORMAL_FORM_H_ + +#include +#include + +namespace tvm { +namespace arith { + +/*! + * \brief Unpack reduction by calling each leaf via fleaf + * + * \param value The expression value. + * \tparam TNode the reduction node to match. + * \tparam FLeaf The callback function at leaf. + */ +template +inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) { + if (const TNode* node = value.as()) { + UnpackReduction(node->a, fleaf); + UnpackReduction(node->b, fleaf); + } else { + fleaf(value); + } +} + +/*! + * \brief Helper function to multiply extent and and re-normalize. + * + * Multiply extent scale and re-normalize to form (x * y) * z + * + * NOTE on multiplication order: when have have shape (s[0], s[1], s[2]), + * we prefer to multiple in order of s[0] * s[1] * s[2] + * + * That means when we are looking at the pattern of split iterator: + * + * - result = (source // lower_factor) % extent * scale + * + * We should take the order of lower_factor, extent, scale. + * Please do best keeping this order to make future simplifcation easy. + * + * \param lhs The lhs iterator + * \param rhs The rhs iterator + * \return the result. + */ +inline PrimExpr MulAndNormalize(const PrimExpr& lhs, const PrimExpr& rhs) { + int64_t cscale = 1; + PrimExpr res = tir::make_const(lhs.dtype(), 1); + auto fcollect = [&](PrimExpr val) { + if (const auto* intimm = val.as()) { + cscale *= intimm->value; + } else { + res = res * val; + } + }; + UnpackReduction(lhs, fcollect); + UnpackReduction(rhs, fcollect); + if (cscale != 1) { + res = res * tir::make_const(res.dtype(), cscale); + } + return res; +} + +} // namespace arith +} // namespace tvm +#endif // TVM_ARITH_PRODUCT_NORMAL_FORM_H_ diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index a26843b7bd05..b0cdfa00c92e 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -120,6 +120,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { /*input_iters=*/loop_var2extent_, /*input_pred=*/op->predicate, /*check_level=*/arith::IterMapLevel::Surjective, + /*analyzer=*/&analzyer_, /*simplify_trivial_iterators=*/!preserve_unit_iters_); if (v.same_as(op->iter_values)) { return GetRef(op); @@ -134,6 +135,8 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { MapNode* opaque_blocks_; /*! \brief The range of loops */ Map loop_var2extent_; + /*! \brief Internal analyzer */ + arith::Analyzer analzyer_; /*! \brief Whether or not to simplify unit iterators */ bool preserve_unit_iters_; }; diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index d51a44887f54..5a248dfbc311 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -21,10 +21,12 @@ * \file flatten_buffer.cc */ +#include #include #include #include +#include "../../arith/ir_mutator_with_analyzer.h" #include "ir_utils.h" namespace tvm { @@ -34,10 +36,11 @@ namespace tir { * \brief Transform multi-dimension BufferLoad/BufferStore into device-supported dimension * for the TIR not contains opaque block. 
*/ -class BufferFlattener : public StmtExprMutator { +class BufferFlattener : public arith::IRMutatorWithAnalyzer { public: static PrimFunc Flatten(PrimFunc func) { - auto pass = BufferFlattener(); + arith::Analyzer ana; + auto pass = BufferFlattener(&ana); auto writer = func.CopyOnWrite(); writer->body = pass.VisitStmt(func->body); // The buffers in func->buffer_map are deliberately left @@ -48,7 +51,10 @@ class BufferFlattener : public StmtExprMutator { } private: - BufferFlattener() {} + using IRMutatorWithAnalyzer::VisitStmt; + using IRMutatorWithAnalyzer::VisitStmt_; + + explicit BufferFlattener(arith::Analyzer* ana) : IRMutatorWithAnalyzer(ana) {} Stmt VisitStmt_(const BlockNode* op) final { ICHECK_EQ(op->match_buffers.size(), 0) @@ -158,13 +164,17 @@ class BufferFlattener : public StmtExprMutator { return it->second; } auto flattened = buf.GetFlattenedBuffer(); + auto writer = flattened.CopyOnWrite(); // TODO(Lunderberg): Move the handling of boolean into a // dedicated pass. if (flattened->dtype == DataType::Bool()) { - auto writer = flattened.CopyOnWrite(); writer->dtype = DataType::Int(8); } + // canonicalize shape + for (size_t i = 0; i < flattened->shape.size(); ++i) { + writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i])); + } buffer_remap_[buf] = flattened; return flattened; @@ -206,10 +216,26 @@ class BufferFlattener : public StmtExprMutator { } } + Array GetSimplifiedElemOffset(const Buffer& buffer, const Array& indices) { + auto flattened_indices = buffer->ElemOffset(indices); + // Use IterMapSimplify to enable constant fold of fused indices + // IterMapSimplify is more powerful and time-consuming than normal + // simplify as it tries to deal with symbolic fusion + // + // Only use to handle indices during layout transformations + // So we restrict the use to here + PrimExpr pred = const_true(); + for (PrimExpr val : iter_predicates_) { + pred = pred && val; + } + return arith::IterMapSimplify(flattened_indices, this->iter_vars_, pred, + arith::IterMapLevel::Surjective, this->analyzer_); + } + template Node VisitBufferAccess(Node node) { ICHECK(node->buffer.defined()); - auto flattened_indices = node->buffer->ElemOffset(node->indices); + auto flattened_indices = GetSimplifiedElemOffset(node->buffer, node->indices); Buffer flattened_buffer = GetFlattenedBuffer(node->buffer); auto writer = node.CopyOnWrite(); @@ -232,8 +258,8 @@ class BufferFlattener : public StmtExprMutator { max_values.push_back(range->min + range->extent - 1); } - Array flattened_min = orig_buf->ElemOffset(min_values); - Array flattened_max = orig_buf->ElemOffset(max_values); + Array flattened_min = GetSimplifiedElemOffset(orig_buf, min_values); + Array flattened_max = GetSimplifiedElemOffset(orig_buf, max_values); Array flattened_ranges; ICHECK_EQ(flattened_min.size(), flattened_max.size()); diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 5ce729604504..7d8f5edb99ea 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -84,6 +84,29 @@ def assert_iter_sum_pattern( tvm.ir.assert_structural_equal(sum_expr, expect_iter) +def assert_iter_map_simplfy( + expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True +): + keys = list(expect_dict.keys()) + imap = tvm.arith.detect_iter_map( + keys, + dom_map, + predicate=predicate, + check_level=check_level, + simplify_trivial_iterators=simplify_trivial_iterators, + 
) + res = tvm.arith.iter_map_simplify( + keys, + dom_map, + predicate=predicate, + check_level=check_level, + simplify_trivial_iterators=simplify_trivial_iterators, + ) + for i, input_expr in enumerate(keys): + expected_expr = expect_dict[input_expr] + tvm.ir.assert_structural_equal(res[i], expected_expr) + + def assert_iter_sum_failure(iters, dom_map, predicate=True, check_level="surjective"): res = tvm.arith.detect_iter_map( list(iters), dom_map, predicate=predicate, check_level=check_level @@ -1087,5 +1110,79 @@ def test_overlapped_fuse(): assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective") +def test_iter_map_simplify_symbolic_case(): + """Test itermap simplify""" + x = tvm.tir.Var("x", "int64") + y = tvm.tir.Var("y", "int64") + z = x * 32 + y + + n = tvm.tir.SizeVar("n", "int64") + + def simple_fuse0(x): + return (x // n) * n + x % n + + assert_iter_map_simplfy({simple_fuse0(x): x}, var_dom([(x, n * 32)])) + + assert_iter_map_simplfy({simple_fuse0(z): z}, var_dom([(x, n), (y, 32)])) + + def fsymbolic_fuse0(x): + return ((x // (n * n)) % 32) * (n * n) + ((x // n) % n) * n + x % n + + assert_iter_map_simplfy({fsymbolic_fuse0(x): x}, var_dom([(x, n * n * 32)])) + + assert_iter_map_simplfy({fsymbolic_fuse0(z): z}, var_dom([(x, n * n), (y, 32)])) + + def fsymbolic_fuse1(x): + return ((x % (n * n * 32)) // (n * n) * n + (x % (n * n) // n)) * n + x % n + + assert_iter_map_simplfy({fsymbolic_fuse1(x): x}, var_dom([(x, n * n * 32)])) + + assert_iter_map_simplfy({fsymbolic_fuse1(z): z}, var_dom([(x, n * n), (y, 32)])) + + def fsymbolic_fuse2(i): + return (i // (n * n) * n + i % (n * n) // n) * n + i % n + + assert_iter_map_simplfy({fsymbolic_fuse2(x): x}, var_dom([(x, n * n * 32)])) + + +def test_iter_map_simplify_symbolic_predicate(): + """Test itermap simplify""" + x = tvm.tir.Var("x", "int64") + y = tvm.tir.Var("y", "int64") + + n = tvm.tir.SizeVar("n", "int64") + + def simple_fuse0(x): + return (x // n) * n + x % n + + z = x * 32 + y + assert_iter_map_simplfy( + {simple_fuse0(z): z}, var_dom([(x, (n + 1) // 2), (y, 32)]), predicate=(z < n * 16) + ) + + def fsymbolic_fuse2(i): + return (i // (n * n) * n + i % (n * n) // n) * n + i % n + + z = x * 64 + y + assert_iter_map_simplfy( + {fsymbolic_fuse2(z): z}, + var_dom([(x, (n * n + 1) // 2), (y, 64)]), + predicate=(z < n * n * 32), + ) + +def test_iter_map_simplify_unit_loop_order(): + """Test itermap simplify""" + x = tvm.tir.Var("x", "int64") + y = tvm.tir.Var("y", "int64") + z = tvm.tir.Var("y", "int64") + + # trivial iterators can be found at any when comparing via scale + # ensure order unchange + assert_iter_map_simplfy( + {x+y+z: x+y+z}, var_dom([(x, 1), (y, 1), (z, 1)]), + simplify_trivial_iterators=False + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index c68dbd9ada6d..20f91b639497 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -148,6 +148,51 @@ def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: C[i * m + j] = B[j] * 2.0 +class TestFusedSymbolic(BaseCompare): + """Dynamically-sized arrrays with fused iterator which can be flattened""" + + def before(a: T.handle, b: T.handle, n: T.int32) -> None: + A = T.match_buffer(a, (32, n, n), "float32") + B = T.match_buffer(b, (32, n, n), "float32") + + for i in range(0, n * n * 32): + B[i // (n * n), (i % 
(n * n)) // n, i % n] = A[i // (n * n), (i % (n * n)) // n, i % n] + + def expected(a: T.handle, b: T.handle, n: T.int32) -> None: + input_A = T.match_buffer(a, (32, n, n), "float32") + input_B = T.match_buffer(b, (32, n, n), "float32") + A = T.Buffer(n * n * 32, "float32", data=input_A.data) + B = T.Buffer(n * n * 32, "float32", data=input_B.data) + + for i in range(0, n * n * 32): + B[i] = A[i] + + +class TestFusedSymbolicWithPredicate(BaseCompare): + """Dynamically-sized arrrays with fused iterator which can be flattened with extra predicate""" + + def before(a: T.handle, b: T.handle, n: T.int32) -> None: + A = T.match_buffer(a, (32, n, n), "float32") + B = T.match_buffer(b, (32, n, n), "float32") + for bx, tx in T.grid((n * n + 1) // 2, 64): + if bx * 64 + tx < n * n * 32: + B[ + (bx * 64 + tx) // (n * n), ((bx * 64 + tx) % (n * n)) // n, (bx * 64 + tx) % n + ] = A[ + (bx * 64 + tx) // (n * n), ((bx * 64 + tx) % (n * n)) // n, (bx * 64 + tx) % n + ] + + def expected(a: T.handle, b: T.handle, n: T.int32) -> None: + input_A = T.match_buffer(a, (32, n, n), "float32") + input_B = T.match_buffer(b, (32, n, n), "float32") + A = T.Buffer(n * n * 32, "float32", data=input_A.data) + B = T.Buffer(n * n * 32, "float32", data=input_B.data) + + for bx, tx in T.grid((n * n + 1) // 2, 64): + if bx * 64 + tx < n * n * 32: + B[bx * 64 + tx] = A[bx * 64 + tx] + + class TestMultiAlloc(BaseCompare): """If multiple allocations occur, all are flattened.""" From e26e407982608593023be29052b9c0bc7be56a41 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 9 Apr 2023 20:33:28 -0400 Subject: [PATCH 2/3] Fix testcases that relates to unit loop ordering and simplifcations. Co-authored-by: Junru Shao --- src/arith/iter_affine_map.cc | 19 +- src/script/printer/tir/expr.cc | 26 +- .../unittest/test_arith_iter_affine_map.py | 13 +- ...t_meta_schedule_postproc_rewrite_layout.py | 12 +- ...schedule_postproc_rewrite_unbound_block.py | 58 +- ...le_schedule_rule_cross_thread_reduction.py | 14 +- .../test_meta_schedule_schedule_rule_mlt.py | 100 +- ..._meta_schedule_schedule_rule_mlt_intrin.py | 151 +- ...test_meta_schedule_schedule_rule_mlt_tc.py | 118 +- .../unittest/test_meta_schedule_space_cpu.py | 2492 ++++++++--------- .../unittest/test_meta_schedule_space_cuda.py | 1018 ++++--- .../test_meta_schedule_space_cuda_async.py | 115 +- .../test_meta_schedule_space_cuda_winograd.py | 402 ++- .../test_meta_schedule_trace_apply.py | 58 +- 14 files changed, 2211 insertions(+), 2385 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 7522a819661f..0ffe298c6ab1 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -879,7 +879,7 @@ class IterMapRewriter : public ExprMutator { if (!has_overlap) return NullOpt; std::vector visited(expr->args.size(), false); - Array reverse_flattened_iters; + std::vector reverse_flattened_iters; // Start eliminating the iterators for (int rend = static_cast(expr->args.size()) - 1; rend >= 0;) { @@ -940,21 +940,10 @@ class IterMapRewriter : public ExprMutator { reverse_flattened_iters.push_back(rhs_iter); } - // if we simplify to sum([iter_sum] * scale), fold to previous iter sum - if (reverse_flattened_iters.size() == 1 && is_zero(expr->base)) { - IterSplitExpr iter = reverse_flattened_iters[0]; - if (is_one(iter->lower_factor) && - analyzer_->CanProveEqual(iter->source->extent, iter->extent)) { - if (auto* ptr = iter->source->source.as()) { - IterSumExpr ref = GetRef(ptr); - MulToLhs(ref.CopyOnWrite(), iter->scale); - return ref; - } - } 
- } - IterSumExpr simplified_sum = expr; - simplified_sum.CopyOnWrite()->args = reverse_flattened_iters; + // flip the order so we preserve the original order + simplified_sum.CopyOnWrite()->args = + Array(reverse_flattened_iters.rbegin(), reverse_flattened_iters.rend()); return simplified_sum; } diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 710f2eab22e2..438eaf06e3b5 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -321,18 +321,20 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return OperationDoc(OperationDocNode::Kind::kDiv, {a, b}); }); -#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc, OpString, OpKind) \ - TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ - .set_dispatch("", \ - [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ - ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ - ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ - PrimExpr ret = tvm::NodeFunc(node->a, node->b); \ - if (!ret->IsInstance()) { \ - return TIR(d, OpString)->Call({a, b}); \ - } \ - return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ - }); +#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc, OpString, OpKind) \ + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ + .set_dispatch( \ + "", [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ + PrimExpr ret = tvm::NodeFunc(node->a, node->b); \ + if (const auto* ret_node = ret.as()) { \ + if (ret_node->a.same_as(node->a) && ret_node->b.same_as(node->b)) { \ + return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ + } \ + } \ + return TIR(d, OpString)->Call({a, b}); \ + }); TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, AddNode, add, "Add", kAdd); TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, SubNode, sub, "Sub", kSub); diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 7d8f5edb99ea..8fdf6b157076 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -1170,17 +1170,24 @@ def fsymbolic_fuse2(i): predicate=(z < n * n * 32), ) + def test_iter_map_simplify_unit_loop_order(): """Test itermap simplify""" x = tvm.tir.Var("x", "int64") y = tvm.tir.Var("y", "int64") - z = tvm.tir.Var("y", "int64") + z = tvm.tir.Var("z", "int64") # trivial iterators can be found at any when comparing via scale # ensure order unchange assert_iter_map_simplfy( - {x+y+z: x+y+z}, var_dom([(x, 1), (y, 1), (z, 1)]), - simplify_trivial_iterators=False + {x + y + z: x + y + z}, var_dom([(x, 1), (y, 1), (z, 1)]), simplify_trivial_iterators=False + ) + + # Even with simplifcation, it should follow the original order + assert_iter_map_simplfy( + {x + y + (z // 4) * 4 + z % 4: x + y + z}, + var_dom([(x, 1), (y, 1), (z, 32)]), + simplify_trivial_iterators=False, ) diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py index c03ba83c0229..94d76a76922c 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py @@ -248,7 +248,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): for 
i3_3_fused_init in T.vectorized(2): with T.block("conv2d_nhwc_init"): - nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + nn = T.axis.spatial(1, i0_1 + i0_2_init + i0_3_init) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) @@ -259,7 +259,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): for i3_3_fused in T.vectorized(2): with T.block("conv2d_nhwc_update"): - nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + nn = T.axis.spatial(1, i0_1 + i0_2 + i0_3) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused) @@ -333,7 +333,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): for i3_3_fused_init in T.vectorized(2): with T.block("conv2d_nhwc_init"): - nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + nn = T.axis.spatial(1, i0_1 + i0_2_init + i0_3_init) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) @@ -344,7 +344,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): for i3_3_fused in T.vectorized(2): with T.block("conv2d_nhwc_update"): - nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + nn = T.axis.spatial(1, i0_1 + i0_2 + i0_3) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused) @@ -425,7 +425,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): for i3_3_fused_init in T.vectorized(2): with T.block("conv2d_nhwc_init"): - nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + nn = T.axis.spatial(1, i0_1 + i0_2_init + i0_3_init) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) @@ -436,7 +436,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): for i3_3_fused in T.vectorized(2): with T.block("conv2d_nhwc_update"): - nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + nn = T.axis.spatial(1, i0_1 + i0_2 + i0_3) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) xx = T.axis.spatial(56, i2_3 + 
i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused) diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py index 963f660ffb67..719d9c2f9515 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py @@ -158,34 +158,16 @@ def main( ax0 = T.axis.spatial( 64, ( - ( - i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 - + i0_i1_fused_0_i0_i1_fused_1_fused_1 - ) - // 32 - * 32 - + ( - i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 - + i0_i1_fused_0_i0_i1_fused_1_fused_1 - ) - % 32 + i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 + + i0_i1_fused_0_i0_i1_fused_1_fused_1 ) // 768, ) ax1 = T.axis.spatial( 768, ( - ( - i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 - + i0_i1_fused_0_i0_i1_fused_1_fused_1 - ) - // 32 - * 32 - + ( - i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 - + i0_i1_fused_0_i0_i1_fused_1_fused_1 - ) - % 32 + i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 + + i0_i1_fused_0_i0_i1_fused_1_fused_1 ) % 768, ) @@ -213,38 +195,18 @@ def main( ax0 = T.axis.spatial( 64, ( - ( - i0_i1_fused_0_i0_i1_fused_1_fused_0 * 262144 - + i0_i1_fused_0_i0_i1_fused_1_fused_1 * 1024 - + i0_i1_fused_0_i0_i1_fused_1_fused_2 - ) - // 32 - * 32 - + ( - i0_i1_fused_0_i0_i1_fused_1_fused_0 * 262144 - + i0_i1_fused_0_i0_i1_fused_1_fused_1 * 1024 - + i0_i1_fused_0_i0_i1_fused_1_fused_2 - ) - % 32 + i0_i1_fused_0_i0_i1_fused_1_fused_0 * 262144 + + i0_i1_fused_0_i0_i1_fused_1_fused_1 * 1024 + + i0_i1_fused_0_i0_i1_fused_1_fused_2 ) // 768, ) ax1 = T.axis.spatial( 768, ( - ( - i0_i1_fused_0_i0_i1_fused_1_fused_0 * 262144 - + i0_i1_fused_0_i0_i1_fused_1_fused_1 * 1024 - + i0_i1_fused_0_i0_i1_fused_1_fused_2 - ) - // 32 - * 32 - + ( - i0_i1_fused_0_i0_i1_fused_1_fused_0 * 262144 - + i0_i1_fused_0_i0_i1_fused_1_fused_1 * 1024 - + i0_i1_fused_0_i0_i1_fused_1_fused_2 - ) - % 32 + i0_i1_fused_0_i0_i1_fused_1_fused_0 * 262144 + + i0_i1_fused_0_i0_i1_fused_1_fused_1 * 1024 + + i0_i1_fused_0_i0_i1_fused_1_fused_2 ) % 768, ) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py index 489b0ddef0e4..6f446ae14eda 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -121,7 +121,7 @@ def softmax_mn_1( for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): with T.block("T_softmax_maxelem"): T.where(ax1_0 * 512 + ax1_1 < 256) - i0_1 = T.axis.spatial(256, ax0 + i0) + i0_1 = T.axis.spatial(256, i0 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.reads(A[i0_1, k]) T.writes(T_softmax_maxelem_shared[i0_1]) @@ -188,7 +188,7 @@ def softmax_mn_2( for ax0, ax1_0 in T.grid(1, 32): for ax1_1 in T.thread_binding(8, thread="threadIdx.x"): with T.block("T_softmax_expsum"): - i0_4 = T.axis.spatial(256, ax0 + i0_3) + i0_4 = T.axis.spatial(256, i0_3 + ax0) k = T.axis.reduce(256, ax1_0 * 8 + ax1_1) T.reads(T_softmax_exp[i0_4, k]) T.writes(T_softmax_expsum_shared[i0_4]) @@ -225,7 +225,7 @@ def softmax_mn_3( for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): with T.block("T_softmax_maxelem"): T.where(ax1_0 * 512 + ax1_1 < 256) - i0_1 = T.axis.spatial(256, ax0 + i0) + i0_1 = T.axis.spatial(256, i0 + ax0) k = T.axis.reduce(256, ax1_0 
* 512 + ax1_1) T.reads(A[i0_1, k]) T.writes(T_softmax_maxelem_shared[i0_1]) @@ -249,7 +249,7 @@ def softmax_mn_3( for ax0, ax1_0 in T.grid(1, 32): for ax1_1 in T.thread_binding(8, thread="threadIdx.x"): with T.block("T_softmax_expsum"): - i0_4 = T.axis.spatial(256, ax0 + i0_3) + i0_4 = T.axis.spatial(256, i0_3 + ax0) k = T.axis.reduce(256, ax1_0 * 8 + ax1_1) T.reads(T_softmax_exp[i0_4, k]) T.writes(T_softmax_expsum_shared[i0_4]) @@ -388,7 +388,7 @@ def softmax_mn_after_inline_2( for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): with T.block("T_softmax_expsum"): T.where(ax1_0 * 512 + ax1_1 < 256) - i0_2 = T.axis.spatial(256, ax0 + i0_3) + i0_2 = T.axis.spatial(256, i0_3 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.reads(A[i0_2, k], T_softmax_maxelem[i0_2]) T.writes(T_softmax_expsum_shared[i0_2]) @@ -424,7 +424,7 @@ def softmax_mn_after_inline_3( for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): with T.block("T_softmax_maxelem"): T.where(ax1_0 * 512 + ax1_1 < 256) - i0_1 = T.axis.spatial(256, ax0 + i0_3) + i0_1 = T.axis.spatial(256, i0_3 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.reads(A[i0_1, k]) T.writes(T_softmax_maxelem_shared[i0_1]) @@ -437,7 +437,7 @@ def softmax_mn_after_inline_3( for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): with T.block("T_softmax_expsum"): T.where(ax1_0 * 512 + ax1_1 < 256) - i0_2 = T.axis.spatial(256, ax0 + i0_3) + i0_2 = T.axis.spatial(256, i0_3 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.reads(A[i0_2, k], T_softmax_maxelem_shared[i0_2]) T.writes(T_softmax_expsum_shared[i0_2]) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py index 497915cd6564..10012c0be5b0 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py @@ -527,58 +527,58 @@ def cpu_conv2d_nhwc( weight: T.Buffer((3, 3, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 56, 56, 64), "float16"), ) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float16") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + PadInput = T.alloc_buffer((1, 58, 58, 64), "float16") for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) - T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( - 1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, - inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else( + 1 <= v_i1 and v_i1 < 57 and 1 <= v_i2 and v_i2 < 57, + inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float16(0), - dtype="float16", ) for ( - i0_0, - i1_0, - i2_0, - i3_0, - i4_0, - i5_0, - i6_0, - i0_1_1, - i1_1_1, - i2_1_1, - i3_1_1, - i4_1, - i5_1, - i6_1, - i0_2, - i1_2, - i2_2, - i3_2, + n_0, + h_0, + w_0, + co_0, + rh_0, + rw_0, + rc_0, + n_1, + h_1, + w_1, + co_1, + rh_1, + rw_1, + rc_1, + n_2, + h_2, + w_2, + co_2, ) in T.grid(1, 1, 2, 1, 3, 3, 16, 1, 14, 2, 1, 1, 1, 4, 1, 4, 14, 64): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_0) - h = T.axis.spatial(56, i1_0 * 
56 + i1_1_1 * 4 + i1_2) - w = T.axis.spatial(56, i2_0 * 28 + i2_1_1 * 14 + i2_2) - co = T.axis.spatial(64, i3_0 * 64 + i3_1_1 * 64 + i3_2) - rh = T.axis.reduce(3, i4_1 + i4_0) - rw = T.axis.reduce(3, i5_0 + i5_1) - rc = T.axis.reduce(64, i6_0 * 4 + i6_1) - T.reads(PadInput[n, h + rh, w + rw, co // 64 * 64 + rc], weight[rh, rw, rc, co]) - T.writes(conv2d_nhwc[n, h, w, co]) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2) + v_h = T.axis.spatial(56, h_0 * 56 + h_1 * 4 + h_2) + v_w = T.axis.spatial(56, w_0 * 28 + w_1 * 14 + w_2) + v_co = T.axis.spatial(64, co_0 * 64 + co_1 * 64 + co_2) + v_rh = T.axis.reduce(3, rh_0 + rh_1) + v_rw = T.axis.reduce(3, rw_0 + rw_1) + v_rc = T.axis.reduce(64, rc_0 * 4 + rc_1) + T.reads( + PadInput[v_n, v_h + v_rh, v_w + v_rw, v_co // 64 * 64 + v_rc], + weight[v_rh, v_rw, v_rc, v_co], + ) + T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co]) T.block_attr({"meta_schedule.tiling_structure": "SRSRS"}) with T.init(): - conv2d_nhwc[n, h, w, co] = T.float16(0) - conv2d_nhwc[n, h, w, co] = ( - conv2d_nhwc[n, h, w, co] - + PadInput[n, h + rh, w + rw, co // 64 * 64 + rc] * weight[rh, rw, rc, co] + conv2d_nhwc[v_n, v_h, v_w, v_co] = T.float16(0) + conv2d_nhwc[v_n, v_h, v_w, v_co] = ( + conv2d_nhwc[v_n, v_h, v_w, v_co] + + PadInput[v_n, v_h + v_rh, v_w + v_rw, v_co // 64 * 64 + v_rc] + * weight[v_rh, v_rw, v_rc, v_co] ) target_hexagon = target.hexagon("v69", num_cores=4) @@ -741,11 +741,11 @@ def pool_blocked_cache_read_write( X: T.Buffer((1, 2, 8, 8, 8, 8, 32), "uint8"), pool: T.Buffer((1, 2, 4, 4, 8, 8, 32), "uint8"), ): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - pool_global = T.alloc_buffer([1, 2, 4, 4, 8, 8, 32], dtype="uint8") - X_global = T.alloc_buffer([1, 2, 8, 8, 8, 8, 32], dtype="uint8") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + pool_global = T.alloc_buffer((1, 2, 4, 4, 8, 8, 32), "uint8") + X_global = T.alloc_buffer((1, 2, 8, 8, 8, 8, 32), "uint8") for b_0, c_o_0, h_o_0, w_o_0, h_i_0, w_i_0, c_i_0 in T.grid(1, 2, 4, 1, 8, 1, 4): - for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(896): + for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in range(896): with T.block("X_global"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(2, c_o_0) @@ -763,11 +763,11 @@ def pool_blocked_cache_read_write( 2, 2, 1, 1, 1, 4, 1, 8, 8 ): with T.block("pool"): - v_b = T.axis.spatial(1, b_1 + b_0) + v_b = T.axis.spatial(1, b_0 + b_1) v_c_o = T.axis.spatial(2, c_o_0 + c_o_1) - v_h_o = T.axis.spatial(4, h_o_1 + h_o_0) + v_h_o = T.axis.spatial(4, h_o_0 + h_o_1) v_w_o = T.axis.spatial(4, w_o_0 * 4 + w_o_1) - v_h_i = T.axis.spatial(8, h_i_1 + h_i_0) + v_h_i = T.axis.spatial(8, h_i_0 + h_i_1) v_w_i = T.axis.spatial(8, w_i_0 * 8 + w_i_1) v_c_i = T.axis.spatial(32, c_i_0 * 8 + c_i_1) v_wh, v_ww = T.axis.remap("RR", [wh, ww]) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py index a1c024d287ad..1fadce6957a3 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py @@ -25,8 +25,8 @@ from tvm.script import tir as T from tvm.target import Target from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN -from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN from tvm.tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN as AVX512_INTRIN +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN def test_x86_conv2d_nchwc(intrin=VNNI_INTRIN, 
target="llvm -mcpu=cascadelake -num-cores=4"): @@ -70,26 +70,27 @@ def conv2d_nchwc( # fmt: off @T.prim_func def x86_conv2d_nchwc_0(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - conv2d_NCHWc_int8_global = T.alloc_buffer([1, 16, 56, 56, 16], dtype="int32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # with T.block("root"): + conv2d_NCHWc_int8_global = T.alloc_buffer((1, 16, 56, 56, 16), "int32") for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1): for i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): with T.block("conv2d_NCHWc_int8_o"): - n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1) + n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3) oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) - ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2) - oc_block_o = T.axis.spatial(1, i4_0_2 + i4_0_3 + i4_0_0 + i4_0_1) - kh = T.axis.reduce(1, i5_1 + i5_0) + ow = T.axis.spatial(56, i3_0 + i3_1 + i3_2 + i3_3) + oc_block_o = T.axis.spatial(1, i4_0_0 + i4_0_1 + i4_0_2 + i4_0_3) + kh = T.axis.reduce(1, i5_0 + i5_1) kw = T.axis.reduce(1, i6_0 + i6_1) ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1) ic_f_inner = T.axis.reduce(4, i8_0 + i8_1) - ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0) - T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) - T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16]) - T.block_attr({"meta_schedule.auto_tensorize":intrin}) + ic_s_inner_o = T.axis.reduce(1, i9_0_0 + i9_0_1) + T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4]) + T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": intrin}) with T.init(): - for i4_1 in T.serial(16): + for i4_1 in range(16): with T.block("conv2d_NCHWc_int8_init"): oc_block_i_init = T.axis.spatial(16, i4_1) T.reads() @@ -100,8 +101,8 @@ def x86_conv2d_nchwc_0(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, i9_1]) T.reads(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) - conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] + T.cast(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], "int32") * T.cast(placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i], "int32") + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] + T.Cast("int32", placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i]) * 
T.Cast("int32", placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 1, 2, 1, 16): with T.block("conv2d_NCHWc_int8_global"): v0 = T.axis.spatial(1, ax0) @@ -115,26 +116,27 @@ def x86_conv2d_nchwc_0(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place @T.prim_func def x86_conv2d_nchwc_1(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - conv2d_NCHWc_int8_global = T.alloc_buffer([1, 16, 56, 56, 16], dtype="int32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # with T.block("root"): + conv2d_NCHWc_int8_global = T.alloc_buffer((1, 16, 56, 56, 16), "int32") for i0_0, i1_0, i2_0, i3_0, i4_0_0 in T.grid(1, 8, 28, 56, 1): for i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): with T.block("conv2d_NCHWc_int8_o"): - n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1) + n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3) oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) - ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2) - oc_block_o = T.axis.spatial(1, i4_0_2 + i4_0_3 + i4_0_0 + i4_0_1) - kh = T.axis.reduce(1, i5_1 + i5_0) + ow = T.axis.spatial(56, i3_0 + i3_1 + i3_2 + i3_3) + oc_block_o = T.axis.spatial(1, i4_0_0 + i4_0_1 + i4_0_2 + i4_0_3) + kh = T.axis.reduce(1, i5_0 + i5_1) kw = T.axis.reduce(1, i6_0 + i6_1) ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1) ic_f_inner = T.axis.reduce(4, i8_0 + i8_1) - ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0) - T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) - T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16]) - T.block_attr({"meta_schedule.auto_tensorize":intrin}) + ic_s_inner_o = T.axis.reduce(1, i9_0_0 + i9_0_1) + T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4]) + T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": intrin}) with T.init(): - for i4_1 in T.serial(16): + for i4_1 in range(16): with T.block("conv2d_NCHWc_int8_init"): oc_block_i_init = T.axis.spatial(16, i4_1) T.reads() @@ -145,8 +147,8 @@ def x86_conv2d_nchwc_1(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, i9_1]) T.reads(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) - conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] + T.cast(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], "int32") * T.cast(placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i], "int32") + 
T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] + T.Cast("int32", placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i]) * T.Cast("int32", placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 2, 2, 1, 16): with T.block("conv2d_NCHWc_int8_global"): v0 = T.axis.spatial(1, ax0) @@ -160,24 +162,25 @@ def x86_conv2d_nchwc_1(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place @T.prim_func def x86_conv2d_nchwc_2(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # with T.block("root"): for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): with T.block("conv2d_NCHWc_int8_o"): - n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1) + n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3) oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) - ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2) - oc_block_o = T.axis.spatial(1, i4_0_2 + i4_0_3 + i4_0_0 + i4_0_1) - kh = T.axis.reduce(1, i5_1 + i5_0) + ow = T.axis.spatial(56, i3_0 + i3_1 + i3_2 + i3_3) + oc_block_o = T.axis.spatial(1, i4_0_0 + i4_0_1 + i4_0_2 + i4_0_3) + kh = T.axis.reduce(1, i5_0 + i5_1) kw = T.axis.reduce(1, i6_0 + i6_1) ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1) ic_f_inner = T.axis.reduce(4, i8_0 + i8_1) - ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0) - T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) - T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) - T.block_attr({"meta_schedule.auto_tensorize":intrin}) + ic_s_inner_o = T.axis.reduce(1, i9_0_0 + i9_0_1) + T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4]) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": intrin}) with T.init(): - for i4_1 in T.serial(16): + for i4_1 in range(16): with T.block("conv2d_NCHWc_int8_init"): oc_block_i_init = T.axis.spatial(16, i4_1) T.reads() @@ -188,8 +191,8 @@ def x86_conv2d_nchwc_2(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, i9_1]) T.reads(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) - conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i] + T.cast(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], "int32") * T.cast(placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, 
ic_s_inner_i], "int32") + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i] + T.Cast("int32", placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i]) * T.Cast("int32", placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 1, 1, 1]), @@ -302,18 +305,16 @@ def dp4a_dense_0( W: T.Buffer((128, 128), "int8"), compute: T.Buffer((128, 128), "int32"), ) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local") - X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") - W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") - for i0_0_i1_0_fused in T.thread_binding(1, thread="blockIdx.x"): - for i0_1_i1_1_fused in T.thread_binding(512, thread="vthread.x"): - for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"): - for i2_0_0 in T.serial(1): - for ax0_ax1_fused in T.serial(16384): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # with T.block("root"): + compute_local = T.alloc_buffer((128, 128), "int32", scope="local") + X_shared = T.alloc_buffer((128, 128), "int8", scope="shared") + W_shared = T.alloc_buffer((128, 128), "int8", scope="shared") + for i_0_j_0_fused in T.thread_binding(1, thread="blockIdx.x"): + for i_1_j_1_fused in T.thread_binding(512, thread="vthread.x"): + for i_2_j_2_fused in T.thread_binding(2, thread="threadIdx.x"): + for k_0_0 in range(1): + for ax0_ax1_fused in range(16384): with T.block("X_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) @@ -321,7 +322,7 @@ def dp4a_dense_0( T.writes(X_shared[v0, v1]) T.block_attr({"meta_schedule.cooperative_fetch": 1}) X_shared[v0, v1] = X[v0, v1] - for ax0_ax1_fused in T.serial(16384): + for ax0_ax1_fused in range(16384): with T.block("W_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) @@ -329,47 +330,43 @@ def dp4a_dense_0( T.writes(W_shared[v0, v1]) T.block_attr({"meta_schedule.cooperative_fetch": 1}) W_shared[v0, v1] = W[v0, v1] - for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(1, 2, 4, 32, 2, 1): + for k_0_1, i_3, j_3, k_0_2, i_4, j_4 in T.grid(1, 2, 4, 32, 2, 1): with T.block("compute_o"): - i = T.axis.spatial( - 128, - i0_1_i1_1_fused // 32 * 8 - + i0_2_i1_2_fused * 4 - + i0_3 * 2 - + i0_4, + v_i = T.axis.spatial( + 128, i_1_j_1_fused // 32 * 8 + i_2_j_2_fused * 4 + i_3 * 2 + i_4 ) - j = T.axis.spatial(128, i1_4 + i0_1_i1_1_fused % 32 * 4 + i1_3) - k_o = T.axis.reduce(32, i2_0_0 * 32 + i2_0_1 * 32 + i2_0_2) + v_j = T.axis.spatial(128, i_1_j_1_fused % 32 * 4 + j_3 + j_4) + v_k_o = T.axis.reduce(32, k_0_0 * 32 + k_0_1 * 32 + k_0_2) T.reads( - X_shared[i, k_o * 4 : k_o * 4 + 4], - W_shared[j, k_o * 4 : k_o * 4 + 4], + X_shared[v_i, v_k_o * 4 : v_k_o * 4 + 4], + W_shared[v_j, v_k_o * 4 : v_k_o * 4 + 4], ) - T.writes(compute_local[i, j]) + T.writes(compute_local[v_i, v_j]) T.block_attr({"meta_schedule.auto_tensorize": "dp4a"}) with T.init(): with T.block("compute_init"): T.reads() - T.writes(compute_local[i, j]) - compute_local[i, j] = 0 - for i2_1 in T.serial(4): + T.writes(compute_local[v_i, v_j]) + compute_local[v_i, v_j] = 0 + for k_1 in range(4): with T.block("compute"): - k_i = T.axis.reduce(4, i2_1) + v_k_i = 
T.axis.reduce(4, k_1) T.reads( - compute_local[i, j], - X_shared[i, k_o * 4 + k_i], - W_shared[j, k_o * 4 + k_i], + compute_local[v_i, v_j], + X_shared[v_i, v_k_o * 4 + v_k_i], + W_shared[v_j, v_k_o * 4 + v_k_i], ) - T.writes(compute_local[i, j]) + T.writes(compute_local[v_i, v_j]) T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) - compute_local[i, j] = compute_local[i, j] + T.cast( - X_shared[i, k_o * 4 + k_i], "int32" - ) * T.cast(W_shared[j, k_o * 4 + k_i], "int32") + compute_local[v_i, v_j] = compute_local[v_i, v_j] + T.Cast( + "int32", X_shared[v_i, v_k_o * 4 + v_k_i] + ) * T.Cast("int32", W_shared[v_j, v_k_o * 4 + v_k_i]) for ax0, ax1 in T.grid(4, 4): with T.block("compute_local"): v0 = T.axis.spatial( - 128, i0_1_i1_1_fused // 32 * 8 + i0_2_i1_2_fused * 4 + ax0 + 128, i_1_j_1_fused // 32 * 8 + i_2_j_2_fused * 4 + ax0 ) - v1 = T.axis.spatial(128, i0_1_i1_1_fused % 32 * 4 + ax1) + v1 = T.axis.spatial(128, i_1_j_1_fused % 32 * 4 + ax1) T.reads(compute_local[v0, v1]) T.writes(compute[v0, v1]) compute[v0, v1] = compute_local[v0, v1] diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index 97ee53f4e409..e101a63d138b 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -17,7 +17,6 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import pytest - import tvm import tvm.testing from tvm import meta_schedule as ms @@ -82,7 +81,8 @@ def test_matmul_relu(shared_scope): # fmt: off @T.prim_func def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # with T.block("root"): C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope=shared_scope) C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) @@ -139,7 +139,7 @@ def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 2, 2, 1): with T.block("C_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) - v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3) + v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3 + ax1_0_4) v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16]) @@ -189,7 +189,6 @@ def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) T.block_attr({"meta_schedule.cooperative_fetch": 4}) compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) - # fmt: on decision_0 = [ ("SamplePerfectTile", [4, 1, 1, 1, 2]), @@ -389,15 +388,16 @@ def test_conv2d(shared_scope): intrin_suffix = shared_scope.replace(".", 
"_") # fmt: off @T.prim_func - def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, 3, 32, 32), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 32), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, 3, 32, 32), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 32), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # with T.block("root"): PadInput = T.alloc_buffer((1, 18, 18, 32), "float16") - conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope=shared_scope) - conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope="wmma.accumulator") - PadInput_reindex_shared = T.alloc_buffer((256, 288), "float16", scope=shared_scope) - weight_reindex_shared = T.alloc_buffer((288, 32), "float16", scope=shared_scope) - PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 288), "float16", scope="wmma.matrix_a") - weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((288, 32), "float16", scope="wmma.matrix_b") + conv2d_nhwc_reindex_shared_dyn = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope=shared_scope) + conv2d_nhwc_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope="wmma.accumulator") + PadInput_reindex_shared_dyn = T.alloc_buffer((256, 288), "float16", scope=shared_scope) + weight_reindex_shared_dyn = T.alloc_buffer((288, 32), "float16", scope=shared_scope) + PadInput_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((256, 288), "float16", scope="wmma.matrix_a") + weight_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((288, 32), "float16", scope="wmma.matrix_b") for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): with T.block("PadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -409,96 +409,96 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): for ax2_0_0 in range(1): for ax0_ax1_fused in range(4608): - with T.block("PadInput_reindex_shared"): + with T.block("PadInput_reindex_shared.dyn"): v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 288) v1 = T.axis.spatial(288, ax0_ax1_fused % 288) T.reads(PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]) - T.writes(PadInput_reindex_shared[v0, v1]) + T.writes(PadInput_reindex_shared_dyn[v0, v1]) T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) - PadInput_reindex_shared[v0, v1] = PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32] + PadInput_reindex_shared_dyn[v0, v1] = PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32] for ax0_ax1_fused in range(4608): - with T.block("weight_reindex_shared"): + with T.block("weight_reindex_shared.dyn"): v0 = T.axis.spatial(288, ax0_ax1_fused // 16) v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16) T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]) - T.writes(weight_reindex_shared[v0, v1]) + T.writes(weight_reindex_shared_dyn[v0, v1]) T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) - weight_reindex_shared[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1] + weight_reindex_shared_dyn[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1] for ax2_0_1 in range(18): for ax0_0, ax1_0 in T.grid(1, 1): - with 
T.block("PadInput_reindex_shared_wmma.matrix_a_o"): + with T.block("PadInput_reindex_shared.dyn_wmma.matrix_a_o"): v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0) v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0) - T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.reads(PadInput_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("PadInput_reindex_shared_wmma.matrix_a"): + with T.block("PadInput_reindex_shared.dyn_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.reads(PadInput_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(1, 1): - with T.block("weight_reindex_shared_wmma.matrix_b_o"): + with T.block("weight_reindex_shared.dyn_wmma.matrix_b_o"): v0_o = T.axis.spatial(18, ax2_0_1 + ax0_0) v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) - T.reads(weight_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.reads(weight_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(weight_reindex_shared_dyn_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("weight_reindex_shared_wmma.matrix_b"): + with T.block("weight_reindex_shared.dyn_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.reads(weight_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(weight_reindex_shared_dyn_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + weight_reindex_shared_dyn_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 1, 1, 1): with T.block("conv2d_nhwc_o"): - v0_o = T.axis.spatial(16, ax0_0_4 + ax0_0_1_ax1_0_1_fused + ax0_0_3) + v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0_3 + ax0_0_4) v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0_3 + ax1_0_4) v2_o = T.axis.reduce(18, ax2_0_0 * 18 + ax2_0_1 + ax2_0_2) - T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, 0:16, 0:16]) + T.reads(PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_shared_dyn_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, 0:16, 0:16]) T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("conv2d_nhwc_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init]) - conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init] = T.float32(0) + T.writes(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init]) + conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("conv2d_nhwc"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i]) + T.reads(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i], PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_dyn_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i]) T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] = conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_shared_dyn_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) for ax2 in range(1): for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): for ax2_1, ax3 in T.grid(1, 1): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + with T.block("conv2d_nhwc_reindex_shared.dyn_wmma.accumulator_o"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2_1, ax3]) v4_o = T.axis.spatial(1, 0) v5_o = T.axis.spatial(1, 0) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) - T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.reads(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, 0:16, 0:16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) for ax4, ax5 in T.grid(16, 16): - with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + with T.block("conv2d_nhwc_reindex_shared.dyn_wmma.accumulator"): v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) - T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) - conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + T.reads(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4_i, v5_i]) + conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(256): - with T.block("conv2d_nhwc_reindex_shared"): + with T.block("conv2d_nhwc_reindex_shared.dyn"): v0, v1, v2 = T.axis.remap("SSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2]) v3 = T.axis.spatial(1, 0) v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused // 16) v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) - T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.reads(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4, v5]) T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) T.block_attr({"meta_schedule.cooperative_fetch": 3}) - conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] + conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4, v5] # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 16, 1, 1, 1]), @@ -751,10 +751,7 @@ def test_padded_matmul_relu(): # fmt: off @T.prim_func def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 127), "float16"), compute: T.Buffer((127, 127), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="shared") C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") @@ -811,7 +808,7 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 2, 2, 1): with T.block("C_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) - v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3) + v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3 + ax1_0_4) v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16]) @@ -862,7 +859,6 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) T.block_attr({"meta_schedule.cooperative_fetch": 4}) compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = 
T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) - # fmt: on decision_0 = [ @@ -903,7 +899,7 @@ def test_conv_1x1(): # fmt: off @T.prim_func def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 4, 1, 1, 16, 16), scope="shared") conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 4, 1, 1, 16, 16), scope="wmma.accumulator") PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", scope="shared") @@ -961,10 +957,10 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 4, 1, 1): with T.block("conv2d_nhwc_o"): - v0 = T.axis.reduce(1, ax0_2 + ax0_0 + ax0_1) - v1 = T.axis.reduce(1, ax1_1 + ax1_2 + ax1_0) - v2_o = T.axis.spatial(16, ax2_0_4 + ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3) - v3_o = T.axis.spatial(4, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3) + v0 = T.axis.reduce(1, ax0_0 + ax0_1 + ax0_2) + v1 = T.axis.reduce(1, ax1_0 + ax1_1 + ax1_2) + v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3 + ax2_0_4) + v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3 + ax3_0_4) v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2) T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, 0:16, 0:16]) diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py index 93e1bdad4438..b57fb042888c 100644 --- a/tests/python/unittest/test_meta_schedule_space_cpu.py +++ b/tests/python/unittest/test_meta_schedule_space_cpu.py @@ -42,81 +42,77 @@ def _design_space(mod): def test_cpu_c1d(): # fmt: off @T.prim_func - def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 258, 64], dtype="float32") - conv1d_nlc_global = T.alloc_buffer([1, 128, 128], dtype="float32") + PadInput = T.alloc_buffer((1, 258, 64), dtype="float32") + conv1d_nlc_global = T.alloc_buffer((1, 128, 128), dtype="float32") for i0, i1, i2 in T.grid(1, 258, 64): with T.block("PadInput"): - i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(inputs[i0_1, i1_1 - 1, i2_1]) - T.writes(PadInput[i0_1, i1_1, i2_1]) - PadInput[i0_1, i1_1, i2_1] = T.if_then_else(1 <= i1_1 and i1_1 
< 257, inputs[i0_1, i1_1 - 1, i2_1], T.float32(0), dtype="float32") - for i0_0, i1_0, i2_0, i0_1_1, i1_1_1, i2_1_1 in T.grid(1, 1, 2, 1, 1, 8): - for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2]) + T.writes(PadInput[v_i0, v_i1, v_i2]) + PadInput[v_i0, v_i1, v_i2] = T.if_then_else(1 <= v_i1 and v_i1 < 257, inputs[v_i0, v_i1 - 1, v_i2], T.float32(0)) + for n_0, l_0, co_0, n_1, l_1, co_1 in T.grid(1, 1, 2, 1, 1, 8): + for rl_0, rc_0, n_2, l_2, co_2, rl_1, rc_1, n_3, l_3, co_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1): with T.block("conv1d_nlc"): - n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_3 + i0_0) - l = T.axis.spatial(128, i1_0 * 128 + i1_1_1 * 128 + i1_2 * 2 + i1_3) - co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1_1 * 8 + i2_2) - rl = T.axis.reduce(3, i3_0 * 3 + i3_1) - rc = T.axis.reduce(64, i4_1 + i4_0) - T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co]) - T.writes(conv1d_nlc_global[n, l, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_l = T.axis.spatial(128, l_0 * 128 + l_1 * 128 + l_2 * 2 + l_3) + v_co = T.axis.spatial(128, co_0 * 64 + co_1 * 8 + co_2 + co_3) + v_rl = T.axis.reduce(3, rl_0 * 3 + rl_1) + v_rc = T.axis.reduce(64, rc_0 + rc_1) + T.reads(PadInput[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc], weight[v_rl, v_rc, v_co]) + T.writes(conv1d_nlc_global[v_n, v_l, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv1d_nlc_global[n, l, co] = T.float32(0) - conv1d_nlc_global[n, l, co] = conv1d_nlc_global[n, l, co] + PadInput[n, l * 2 + rl, co // 128 * 64 + rc] * weight[rl, rc, co] + conv1d_nlc_global[v_n, v_l, v_co] = T.float32(0) + conv1d_nlc_global[v_n, v_l, v_co] = conv1d_nlc_global[v_n, v_l, v_co] + PadInput[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc] * weight[v_rl, v_rc, v_co] for ax0, ax1, ax2 in T.grid(1, 128, 8): with T.block("conv1d_nlc_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(128, i2_0 * 64 + i2_1_1 * 8 + ax2) + v2 = T.axis.spatial(128, co_0 * 64 + co_1 * 8 + ax2) T.reads(conv1d_nlc_global[v0, v1, v2]) T.writes(conv1d_nlc[v0, v1, v2]) conv1d_nlc[v0, v1, v2] = conv1d_nlc_global[v0, v1, v2] @T.prim_func def c1d_1(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 258, 64], dtype="float32") - conv1d_nlc_global = T.alloc_buffer([1, 128, 128], dtype="float32") - for i0_0, i1_0, i2_0 in T.grid(1, 1, 2): - for i0_1, i1_1, i2_1 in T.grid(1, 1, 8): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 258, 64)) + conv1d_nlc_global = T.alloc_buffer((1, 128, 128)) + for n_0, l_0, co_0 in T.grid(1, 1, 2): + for n_1, l_1, co_1 in T.grid(1, 1, 8): for ax0, ax1, ax2 in T.grid(1, 257, 64): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(258, ax1) - i2 = T.axis.spatial(64, ax2) - T.reads(inputs[i0, i1 - 1, i2]) - 
T.writes(PadInput[i0, i1, i2]) - PadInput[i0, i1, i2] = T.if_then_else(1 <= i1 and i1 < 257, inputs[i0, i1 - 1, i2], T.float32(0), dtype="float32") - for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(258, ax1) + v_i2 = T.axis.spatial(64, ax2) + T.reads(inputs[v_i0, v_i1 - 1, v_i2]) + T.writes(PadInput[v_i0, v_i1, v_i2]) + PadInput[v_i0, v_i1, v_i2] = T.if_then_else(1 <= v_i1 and v_i1 < 257, inputs[v_i0, v_i1 - 1, v_i2], T.float32(0)) + for rl_0, rc_0, n_2, l_2, co_2, rl_1, rc_1, n_3, l_3, co_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1): with T.block("conv1d_nlc"): - n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0) - l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3) - co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2) - rl = T.axis.reduce(3, i3_0 * 3 + i3_1) - rc = T.axis.reduce(64, i4_1 + i4_0) - T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co]) - T.writes(conv1d_nlc_global[n, l, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_l = T.axis.spatial(128, l_0 * 128 + l_1 * 128 + l_2 * 2 + l_3) + v_co = T.axis.spatial(128, co_0 * 64 + co_1 * 8 + co_2 + co_3) + v_rl = T.axis.reduce(3, rl_0 * 3 + rl_1) + v_rc = T.axis.reduce(64, rc_0 + rc_1) + T.reads(PadInput[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc], weight[v_rl, v_rc, v_co]) + T.writes(conv1d_nlc_global[v_n, v_l, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv1d_nlc_global[n, l, co] = T.float32(0) - conv1d_nlc_global[n, l, co] = conv1d_nlc_global[n, l, co] + PadInput[n, l * 2 + rl, co // 128 * 64 + rc] * weight[rl, rc, co] + conv1d_nlc_global[v_n, v_l, v_co] = T.float32(0) + conv1d_nlc_global[v_n, v_l, v_co] = conv1d_nlc_global[v_n, v_l, v_co] + PadInput[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc] * weight[v_rl, v_rc, v_co] for ax0, ax1, ax2 in T.grid(1, 128, 64): with T.block("conv1d_nlc_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(128, i2_0 * 64 + ax2) + v2 = T.axis.spatial(128, co_0 * 64 + ax2) T.reads(conv1d_nlc_global[v0, v1, v2]) T.writes(conv1d_nlc[v0, v1, v2]) conv1d_nlc[v0, v1, v2] = conv1d_nlc_global[v0, v1, v2] @@ -125,24 +121,23 @@ def c1d_1(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 12 def c1d_2(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 8, 1, 64, 1, 64, 8, 3, 1, 1, 2, 1): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + for n_0, l_0, co_0, n_1, l_1, co_1, rl_0, rc_0, n_2, l_2, co_2, rl_1, rc_1, n_3, l_3, co_3 in T.grid(1, 1, 2, 1, 1, 8, 1, 64, 1, 64, 8, 3, 1, 1, 2, 1): with T.block("conv1d_nlc"): - n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0) - l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3) - co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2) - rl = T.axis.reduce(3, i3_0 * 3 + i3_1) - rc = T.axis.reduce(64, i4_1 + i4_0) - T.reads(inputs[n, l 
* 2 + rl - 1, co // 128 * 64 + rc], weight[rl, rc, co]) - T.writes(conv1d_nlc[n, l, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_l = T.axis.spatial(128, l_0 * 128 + l_1 * 128 + l_2 * 2 + l_3) + v_co = T.axis.spatial(128, co_0 * 64 + co_1 * 8 + co_2 + co_3) + v_rl = T.axis.reduce(3, rl_0 * 3 + rl_1) + v_rc = T.axis.reduce(64, rc_0 + rc_1) + T.reads(inputs[v_n, v_l * 2 + v_rl - 1, v_co // 128 * 64 + v_rc], weight[v_rl, v_rc, v_co]) + T.writes(conv1d_nlc[v_n, v_l, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv1d_nlc[n, l, co] = T.float32(0) - conv1d_nlc[n, l, co] = conv1d_nlc[n, l, co] + T.if_then_else(1 <= l * 2 + rl and l * 2 + rl < 257, inputs[n, l * 2 + rl - 1, co // 128 * 64 + rc], T.float32(0), dtype="float32") * weight[rl, rc, co] + conv1d_nlc[v_n, v_l, v_co] = T.float32(0) + conv1d_nlc[v_n, v_l, v_co] = conv1d_nlc[v_n, v_l, v_co] + T.if_then_else(1 <= v_l * 2 + v_rl and v_l * 2 + v_rl < 257, inputs[v_n, v_l * 2 + v_rl - 1, v_co // 128 * 64 + v_rc], T.float32(0)) * weight[v_rl, v_rc, v_co] # fmt: on decision_0 = [ @@ -187,127 +182,121 @@ def test_cpu_c2d(): # fmt: off @T.prim_func def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") - conv2d_nhwc_global = T.alloc_buffer([1, 112, 112, 64], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1 in T.grid(1, 7, 4, 2, 1, 1, 28): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 230, 230, 3)) + conv2d_nhwc_global = T.alloc_buffer((1, 112, 112, 64)) + for n_0, h_0, w_0, co_0, n_1, h_1, w_1 in T.grid(1, 7, 4, 2, 1, 1, 28): for ax0, ax1, ax2, ax3 in T.grid(1, 37, 7, 3): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(230, i1_0 * 32 + ax1) - i2 = T.axis.spatial(230, i2_0 * 56 + i2_1 * 2 + ax2) - i3 = T.axis.spatial(3, ax3) - T.reads(inputs[i0, i1 - 3, i2 - 3, i3]) - T.writes(PadInput[i0, i1, i2, i3]) - PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32") - for i3_1 in T.serial(8): - for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(230, h_0 * 32 + ax1) + v_i2 = T.axis.spatial(230, w_0 * 56 + w_1 * 2 + ax2) + v_i3 = T.axis.spatial(3, ax3) + T.reads(inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) + for co_1 in range(8): + for rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - h = T.axis.spatial(112, i1_0 * 16 + i1_1 * 16 + 
i1_2 * 8 + i1_3) - w = T.axis.spatial(112, i2_3 + i2_0 * 28 + i2_1 + i2_2) - co = T.axis.spatial(64, i3_0 * 32 + i3_1 * 4 + i3_2 * 4 + i3_3) - rh = T.axis.reduce(7, i4_1 + i4_0) - rw = T.axis.reduce(7, i5_0 + i5_1) - rc = T.axis.reduce(3, i6_0 * 3 + i6_1) - T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rh, rw, rc, co]) - T.writes(conv2d_nhwc_global[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(112, h_0 * 16 + h_1 * 16 + h_2 * 8 + h_3) + v_w = T.axis.spatial(112, w_0 * 28 + w_1 + w_2 + w_3) + v_co = T.axis.spatial(64, co_0 * 32 + co_1 * 4 + co_2 * 4 + co_3) + v_rh = T.axis.reduce(7, rh_0 + rh_1) + v_rw = T.axis.reduce(7, rw_0 + rw_1) + v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) + T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_nhwc_global[n, h, w, co] = T.float32(0) - conv2d_nhwc_global[n, h, w, co] = conv2d_nhwc_global[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rh, rw, rc, co] + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 16, 1, 4): with T.block("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(112, i1_0 * 16 + ax1) - v2 = T.axis.spatial(112, i2_0 * 28 + i2_1 + ax2) - v3 = T.axis.spatial(64, i3_0 * 32 + i3_1 * 4 + ax3) + v1 = T.axis.spatial(112, h_0 * 16 + ax1) + v2 = T.axis.spatial(112, w_0 * 28 + w_1 + ax2) + v3 = T.axis.spatial(64, co_0 * 32 + co_1 * 4 + ax3) T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) T.writes(conv2d_nhwc[v0, v1, v2, v3]) conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def c2d_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") - conv2d_nhwc_global = T.alloc_buffer([1, 112, 112, 64], dtype="float32") + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 230, 230, 3)) + conv2d_nhwc_global = T.alloc_buffer((1, 112, 112, 64)) for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1]) - T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(3 <= i1_1 and i1_1 < 227 and 3 <= i2_1 and i2_1 < 227, inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1], T.float32(0), dtype="float32") - for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 7, 4, 2): - for i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 28, 8, 7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4): + 
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) + for n_0, h_0, w_0, co_0 in T.grid(1, 7, 4, 2): + for n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 1, 28, 8, 7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1_1 + i0_2) - h = T.axis.spatial(112, i1_0 * 16 + i1_1_1 * 16 + i1_2 * 8 + i1_3) - w = T.axis.spatial(112, i2_3 + i2_0 * 28 + i2_1_1 + i2_2) - co = T.axis.spatial(64, i3_0 * 32 + i3_1_1 * 4 + i3_2 * 4 + i3_3) - rh = T.axis.reduce(7, i4_1 + i4_0) - rw = T.axis.reduce(7, i5_0 + i5_1) - rc = T.axis.reduce(3, i6_0 * 3 + i6_1) - T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rh, rw, rc, co]) - T.writes(conv2d_nhwc_global[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(112, h_0 * 16 + h_1 * 16 + h_2 * 8 + h_3) + v_w = T.axis.spatial(112, w_0 * 28 + w_1 + w_2 + w_3) + v_co = T.axis.spatial(64, co_0 * 32 + co_1 * 4 + co_2 * 4 + co_3) + v_rh = T.axis.reduce(7, rh_0 + rh_1) + v_rw = T.axis.reduce(7, rw_0 + rw_1) + v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) + T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_nhwc_global[n, h, w, co] = T.float32(0) - conv2d_nhwc_global[n, h, w, co] = conv2d_nhwc_global[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rh, rw, rc, co] + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 16, 28, 32): with T.block("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(112, i1_0 * 16 + ax1) - v2 = T.axis.spatial(112, i2_0 * 28 + ax2) - v3 = T.axis.spatial(64, i3_0 * 32 + ax3) + v1 = T.axis.spatial(112, h_0 * 16 + ax1) + v2 = T.axis.spatial(112, w_0 * 28 + ax2) + v3 = T.axis.spatial(64, co_0 * 32 + ax3) T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) T.writes(conv2d_nhwc[v0, v1, v2, v3]) conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def c2d_2(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") - for i0_0, i1_0 in T.grid(1, 7): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 230, 230, 3)) + for n_0, h_0 in T.grid(1, 7): for ax0, ax1, ax2, ax3 in T.grid(1, 37, 229, 3): with T.block("PadInput"): - i0 
= T.axis.spatial(1, ax0) - i1 = T.axis.spatial(230, i1_0 * 32 + ax1) - i2 = T.axis.spatial(230, ax2) - i3 = T.axis.spatial(3, ax3) - T.reads(inputs[i0, i1 - 3, i2 - 3, i3]) - T.writes(PadInput[i0, i1, i2, i3]) - PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32") - for i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(4, 2, 1, 1, 28, 8, 7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(230, h_0 * 32 + ax1) + v_i2 = T.axis.spatial(230, ax2) + v_i3 = T.axis.spatial(3, ax3) + T.reads(inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) + for w_0, co_0, n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(4, 2, 1, 1, 28, 8, 7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - h = T.axis.spatial(112, i1_0 * 16 + i1_1 * 16 + i1_2 * 8 + i1_3) - w = T.axis.spatial(112, i2_3 + i2_0 * 28 + i2_1 + i2_2) - co = T.axis.spatial(64, i3_0 * 32 + i3_1 * 4 + i3_2 * 4 + i3_3) - rh = T.axis.reduce(7, i4_1 + i4_0) - rw = T.axis.reduce(7, i5_0 + i5_1) - rc = T.axis.reduce(3, i6_0 * 3 + i6_1) - T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rh, rw, rc, co]) - T.writes(conv2d_nhwc[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(112, h_0 * 16 + h_1 * 16 + h_2 * 8 + h_3) + v_w = T.axis.spatial(112, w_0 * 28 + w_1 + w_2 + w_3) + v_co = T.axis.spatial(64, co_0 * 32 + co_1 * 4 + co_2 * 4 + co_3) + v_rh = T.axis.reduce(7, rh_0 + rh_1) + v_rw = T.axis.reduce(7, rw_0 + rw_1) + v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) + T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_nhwc[n, h, w, co] = T.float32(0) - conv2d_nhwc[n, h, w, co] = conv2d_nhwc[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rh, rw, rc, co] + conv2d_nhwc[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc[v_n, v_h, v_w, v_co] = conv2d_nhwc[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] # fmt: on decision_0 = [ @@ -358,142 +347,136 @@ def test_cpu_c3d(): # fmt: off @T.prim_func def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 22, 230, 230, 3], dtype="float32") - conv3d_ndhwc_global = T.alloc_buffer([1, 8, 112, 112, 64], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i4_0 in T.grid(1, 2, 4, 1, 2): + T.block_attr({"meta_schedule.parallel": 288, 
"meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 22, 230, 230, 3)) + conv3d_ndhwc_global = T.alloc_buffer((1, 8, 112, 112, 64)) + for n_0, d_0, h_0, w_0, co_0 in T.grid(1, 2, 4, 1, 2): for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 13, 61, 229, 3): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(22, i1_0 * 8 + ax1) - i2 = T.axis.spatial(230, i2_0 * 56 + ax2) - i3 = T.axis.spatial(230, ax3) - i4 = T.axis.spatial(3, ax4) - T.reads(inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4]) - T.writes(PadInput[i0, i1, i2, i3, i4]) - PadInput[i0, i1, i2, i3, i4] = T.if_then_else(3 <= i1 and i1 < 19 and 3 <= i2 and i2 < 227 and 3 <= i3 and i3 < 227, inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4], T.float32(0), dtype="float32") - for i0_1, i1_1, i2_1, i3_1, i4_1 in T.grid(1, 4, 4, 14, 1): - for i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(22, d_0 * 8 + ax1) + v_i2 = T.axis.spatial(230, h_0 * 56 + ax2) + v_i3 = T.axis.spatial(230, ax3) + v_i4 = T.axis.spatial(3, ax4) + T.reads(inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3 - 3, v_i4]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3, v_i4]) + PadInput[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(3 <= v_i1 and v_i1 < 19 and 3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3 - 3, v_i4], T.float32(0)) + for n_1, d_1, h_1, w_1, co_1 in T.grid(1, 4, 4, 14, 1): + for rd_0, rh_0, rw_0, rc_0, n_2, d_2, h_2, w_2, co_2, rd_1, rh_1, rw_1, rc_1, n_3, d_3, h_3, w_3, co_3 in T.grid(1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1): with T.block("conv3d_ndhwc"): - n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0) - d = T.axis.spatial(8, i1_3 + i1_0 * 4 + i1_1 + i1_2) - h = T.axis.spatial(112, i2_0 * 28 + i2_1 * 7 + i2_2 * 7 + i2_3) - w = T.axis.spatial(112, i3_0 * 112 + i3_1 * 8 + i3_2 * 8 + i3_3) - co = T.axis.spatial(64, i4_3 + i4_0 * 32 + i4_1 * 32 + i4_2) - rd = T.axis.reduce(7, i5_0 * 7 + i5_1) - rh = T.axis.reduce(7, i6_1 + i6_0) - rw = T.axis.reduce(7, i7_0 + i7_1) - rc = T.axis.reduce(3, i8_1 + i8_0) - T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co]) - T.writes(conv3d_ndhwc_global[n, d, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_d = T.axis.spatial(8, d_0 * 4 + d_1 + d_2 + d_3) + v_h = T.axis.spatial(112, h_0 * 28 + h_1 * 7 + h_2 * 7 + h_3) + v_w = T.axis.spatial(112, w_0 * 112 + w_1 * 8 + w_2 * 8 + w_3) + v_co = T.axis.spatial(64, co_0 * 32 + co_1 * 32 + co_2 + co_3) + v_rd = T.axis.reduce(7, rd_0 * 7 + rd_1) + v_rh = T.axis.reduce(7, rh_0 + rh_1) + v_rw = T.axis.reduce(7, rw_0 + rw_1) + v_rc = T.axis.reduce(3, rc_0 + rc_1) + T.reads(PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rd, v_rh, v_rw, v_rc, v_co]) + T.writes(conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv3d_ndhwc_global[n, d, h, w, co] = T.float32(0) - conv3d_ndhwc_global[n, d, h, w, co] = conv3d_ndhwc_global[n, d, h, w, co] + PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rd, rh, rw, rc, co] + conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co] = T.float32(0) + conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co] = conv3d_ndhwc_global[v_n, v_d, v_h, v_w, 
v_co] + PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rd, v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 1, 7, 8, 32): with T.block("conv3d_ndhwc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(8, i1_0 * 4 + i1_1 + ax1) - v2 = T.axis.spatial(112, i2_0 * 28 + i2_1 * 7 + ax2) - v3 = T.axis.spatial(112, i3_1 * 8 + ax3) - v4 = T.axis.spatial(64, i4_0 * 32 + ax4) + v1 = T.axis.spatial(8, d_0 * 4 + d_1 + ax1) + v2 = T.axis.spatial(112, h_0 * 28 + h_1 * 7 + ax2) + v3 = T.axis.spatial(112, w_1 * 8 + ax3) + v4 = T.axis.spatial(64, co_0 * 32 + ax4) T.reads(conv3d_ndhwc_global[v0, v1, v2, v3, v4]) T.writes(conv3d_ndhwc[v0, v1, v2, v3, v4]) conv3d_ndhwc[v0, v1, v2, v3, v4] = conv3d_ndhwc_global[v0, v1, v2, v3, v4] @T.prim_func def c3d_1(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 22, 230, 230, 3], dtype="float32") - conv3d_ndhwc_global = T.alloc_buffer([1, 8, 112, 112, 64], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i4_0 in T.grid(1, 2, 4, 1, 2): - for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 4, 4, 14): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 22, 230, 230, 3)) + conv3d_ndhwc_global = T.alloc_buffer((1, 8, 112, 112, 64)) + for n_0, d_0, h_0, w_0, co_0 in T.grid(1, 2, 4, 1, 2): + for n_1, d_1, h_1, w_1 in T.grid(1, 4, 4, 14): for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 7, 19, 21, 3): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(22, i1_0 * 8 + i1_1 * 2 + ax1) - i2 = T.axis.spatial(230, i2_0 * 56 + i2_1 * 14 + ax2) - i3 = T.axis.spatial(230, i3_1 * 16 + ax3) - i4 = T.axis.spatial(3, ax4) - T.reads(inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4]) - T.writes(PadInput[i0, i1, i2, i3, i4]) - PadInput[i0, i1, i2, i3, i4] = T.if_then_else(3 <= i1 and i1 < 19 and 3 <= i2 and i2 < 227 and 3 <= i3 and i3 < 227, inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4], T.float32(0), dtype="float32") - for i4_1, i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(22, d_0 * 8 + d_1 * 2 + ax1) + v_i2 = T.axis.spatial(230, h_0 * 56 + h_1 * 14 + ax2) + v_i3 = T.axis.spatial(230, w_1 * 16 + ax3) + v_i4 = T.axis.spatial(3, ax4) + T.reads(inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3 - 3, v_i4]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3, v_i4]) + PadInput[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(3 <= v_i1 and v_i1 < 19 and 3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3 - 3, v_i4], T.float32(0)) + for co_1, rd_0, rh_0, rw_0, rc_0, n_2, d_2, h_2, w_2, co_2, rd_1, rh_1, rw_1, rc_1, n_3, d_3, h_3, w_3, co_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1): with T.block("conv3d_ndhwc"): - n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0) - d = T.axis.spatial(8, i1_3 + i1_0 * 4 + i1_1 + i1_2) - h = T.axis.spatial(112, i2_0 * 28 + i2_1 * 7 + 
i2_2 * 7 + i2_3) - w = T.axis.spatial(112, i3_0 * 112 + i3_1 * 8 + i3_2 * 8 + i3_3) - co = T.axis.spatial(64, i4_3 + i4_0 * 32 + i4_1 * 32 + i4_2) - rd = T.axis.reduce(7, i5_0 * 7 + i5_1) - rh = T.axis.reduce(7, i6_1 + i6_0) - rw = T.axis.reduce(7, i7_0 + i7_1) - rc = T.axis.reduce(3, i8_1 + i8_0) - T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co]) - T.writes(conv3d_ndhwc_global[n, d, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_d = T.axis.spatial(8, d_0 * 4 + d_1 + d_2 + d_3) + v_h = T.axis.spatial(112, h_0 * 28 + h_1 * 7 + h_2 * 7 + h_3) + v_w = T.axis.spatial(112, w_0 * 112 + w_1 * 8 + w_2 * 8 + w_3) + v_co = T.axis.spatial(64, co_0 * 32 + co_1 * 32 + co_2 + co_3) + v_rd = T.axis.reduce(7, rd_0 * 7 + rd_1) + v_rh = T.axis.reduce(7, rh_0 + rh_1) + v_rw = T.axis.reduce(7, rw_0 + rw_1) + v_rc = T.axis.reduce(3, rc_0 + rc_1) + T.reads(PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rd, v_rh, v_rw, v_rc, v_co]) + T.writes(conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv3d_ndhwc_global[n, d, h, w, co] = T.float32(0) - conv3d_ndhwc_global[n, d, h, w, co] = conv3d_ndhwc_global[n, d, h, w, co] + PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rd, rh, rw, rc, co] + conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co] = T.float32(0) + conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co] = conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co] + PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rd, v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 4, 28, 112, 32): with T.block("conv3d_ndhwc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(8, i1_0 * 4 + ax1) - v2 = T.axis.spatial(112, i2_0 * 28 + ax2) + v1 = T.axis.spatial(8, d_0 * 4 + ax1) + v2 = T.axis.spatial(112, h_0 * 28 + ax2) v3 = T.axis.spatial(112, ax3) - v4 = T.axis.spatial(64, i4_0 * 32 + ax4) + v4 = T.axis.spatial(64, co_0 * 32 + ax4) T.reads(conv3d_ndhwc_global[v0, v1, v2, v3, v4]) T.writes(conv3d_ndhwc[v0, v1, v2, v3, v4]) conv3d_ndhwc[v0, v1, v2, v3, v4] = conv3d_ndhwc_global[v0, v1, v2, v3, v4] @T.prim_func def c3d_2(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 22, 230, 230, 3], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i4_0, i0_1, i1_1, i2_1, i3_1 in T.grid(1, 2, 4, 1, 2, 1, 4, 4, 14): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 22, 230, 230, 3)) + for n_0, d_0, h_0, w_0, co_0, n_1, d_1, h_1, w_1 in T.grid(1, 2, 4, 1, 2, 1, 4, 4, 14): for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 7, 19, 21, 3): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(22, i1_0 * 8 + i1_1 * 2 + ax1) - i2 = T.axis.spatial(230, i2_0 * 56 + i2_1 * 14 + ax2) - i3 = T.axis.spatial(230, i3_1 * 16 + ax3) - i4 = T.axis.spatial(3, ax4) - 
T.reads(inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4]) - T.writes(PadInput[i0, i1, i2, i3, i4]) - PadInput[i0, i1, i2, i3, i4] = T.if_then_else(3 <= i1 and i1 < 19 and 3 <= i2 and i2 < 227 and 3 <= i3 and i3 < 227, inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4], T.float32(0), dtype="float32") - for i4_1, i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(22, d_0 * 8 + d_1 * 2 + ax1) + v_i2 = T.axis.spatial(230, h_0 * 56 + h_1 * 14 + ax2) + v_i3 = T.axis.spatial(230, w_1 * 16 + ax3) + v_i4 = T.axis.spatial(3, ax4) + T.reads(inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3 - 3, v_i4]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3, v_i4]) + PadInput[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(3 <= v_i1 and v_i1 < 19 and 3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3 - 3, v_i4], T.float32(0)) + for co_1, rd_0, rh_0, rw_0, rc_0, n_2, d_2, h_2, w_2, co_2, rd_1, rh_1, rw_1, rc_1, n_3, d_3, h_3, w_3, co_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1): with T.block("conv3d_ndhwc"): - n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0) - d = T.axis.spatial(8, i1_3 + i1_0 * 4 + i1_1 + i1_2) - h = T.axis.spatial(112, i2_0 * 28 + i2_1 * 7 + i2_2 * 7 + i2_3) - w = T.axis.spatial(112, i3_0 * 112 + i3_1 * 8 + i3_2 * 8 + i3_3) - co = T.axis.spatial(64, i4_3 + i4_0 * 32 + i4_1 * 32 + i4_2) - rd = T.axis.reduce(7, i5_0 * 7 + i5_1) - rh = T.axis.reduce(7, i6_1 + i6_0) - rw = T.axis.reduce(7, i7_0 + i7_1) - rc = T.axis.reduce(3, i8_1 + i8_0) - T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co]) - T.writes(conv3d_ndhwc[n, d, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_d = T.axis.spatial(8, d_0 * 4 + d_1 + d_2 + d_3) + v_h = T.axis.spatial(112, h_0 * 28 + h_1 * 7 + h_2 * 7 + h_3) + v_w = T.axis.spatial(112, w_0 * 112 + w_1 * 8 + w_2 * 8 + w_3) + v_co = T.axis.spatial(64, co_0 * 32 + co_1 * 32 + co_2 + co_3) + v_rd = T.axis.reduce(7, rd_0 * 7 + rd_1) + v_rh = T.axis.reduce(7, rh_0 + rh_1) + v_rw = T.axis.reduce(7, rw_0 + rw_1) + v_rc = T.axis.reduce(3, rc_0 + rc_1) + T.reads(PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rd, v_rh, v_rw, v_rc, v_co]) + T.writes(conv3d_ndhwc[v_n, v_d, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv3d_ndhwc[n, d, h, w, co] = T.float32(0) - conv3d_ndhwc[n, d, h, w, co] = conv3d_ndhwc[n, d, h, w, co] + PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rd, rh, rw, rc, co] + conv3d_ndhwc[v_n, v_d, v_h, v_w, v_co] = T.float32(0) + conv3d_ndhwc[v_n, v_d, v_h, v_w, v_co] = conv3d_ndhwc[v_n, v_d, v_h, v_w, v_co] + PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rd, v_rh, v_rw, v_rc, v_co] # fmt: on decision_0 = [ @@ -550,137 +533,131 @@ def test_cpu_cap(): # fmt: off @T.prim_func def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - 
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 18, 18, 4, 4, 32], dtype="float32") - conv2d_capsule_nhwijc_global = T.alloc_buffer([1, 8, 8, 4, 4, 32], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0, i0_1, i1_1 in T.grid(1, 2, 1, 1, 1, 1, 1, 4): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 18, 18, 4, 4, 32)) + conv2d_capsule_nhwijc_global = T.alloc_buffer((1, 8, 8, 4, 4, 32)) + for n_0, h_0, w_0, cap_i_0, cap_j_0, co_0, n_1, h_1 in T.grid(1, 2, 1, 1, 1, 1, 1, 4): for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 3, 17, 4, 4, 32): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(18, i1_0 * 8 + i1_1 * 2 + ax1) - i2 = T.axis.spatial(18, ax2) - i3, i4, i5 = T.axis.remap("SSS", [ax3, ax4, ax5]) - T.reads(inputs[i0, i1 - 1, i2 - 1, i3, i4, i5]) - T.writes(PadInput[i0, i1, i2, i3, i4, i5]) - PadInput[i0, i1, i2, i3, i4, i5] = T.if_then_else(1 <= i1 and i1 < 17 and 1 <= i2 and i2 < 17, inputs[i0, i1 - 1, i2 - 1, i3, i4, i5], T.float32(0), dtype="float32") - for i2_1, i3_1, i4_1, i5_1 in T.grid(4, 1, 4, 2): - for i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(18, h_0 * 8 + h_1 * 2 + ax1) + v_i2 = T.axis.spatial(18, ax2) + v_i3, v_i4, v_i5 = T.axis.remap("SSS", [ax3, ax4, ax5]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3, v_i4, v_i5]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3, v_i4, v_i5]) + PadInput[v_i0, v_i1, v_i2, v_i3, v_i4, v_i5] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3, v_i4, v_i5], T.float32(0)) + for w_1, cap_i_1, cap_j_1, co_1 in T.grid(4, 1, 4, 2): + for rh_0, rw_0, cap_k_0, rc_0, n_2, h_2, w_2, cap_i_2, cap_j_2, co_2, rh_1, rw_1, cap_k_1, rc_1, n_3, h_3, w_3, cap_i_3, cap_j_3, co_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16): with T.block("conv2d_capsule_nhwijc"): - n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1) - h = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3) - w = T.axis.spatial(8, i2_0 * 8 + i2_1 * 2 + i2_2 + i2_3) - cap_i = T.axis.spatial(4, i3_0 * 4 + i3_1 * 4 + i3_2 * 4 + i3_3) - cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1 + i4_2 + i4_3) - co = T.axis.spatial(32, i5_0 * 32 + i5_1 * 16 + i5_2 * 16 + i5_3) - rh = T.axis.reduce(3, i6_0 * 3 + i6_1) - rw = T.axis.reduce(3, i7_1 + i7_0) - cap_k = T.axis.reduce(4, i8_0 + i8_1) - rc = T.axis.reduce(32, i9_0 * 32 + i9_1) - T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co]) - T.writes(conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(8, h_0 * 4 + h_1 + h_2 + h_3) + v_w = T.axis.spatial(8, w_0 * 8 + w_1 * 2 + w_2 + w_3) + v_cap_i = T.axis.spatial(4, cap_i_0 * 4 + cap_i_1 * 4 + cap_i_2 * 4 + cap_i_3) + v_cap_j = T.axis.spatial(4, cap_j_0 * 4 + cap_j_1 + cap_j_2 + cap_j_3) + v_co = T.axis.spatial(32, co_0 * 32 + co_1 * 16 + co_2 * 16 + co_3) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1) + v_rw = T.axis.reduce(3, rw_0 + rw_1) + v_cap_k = T.axis.reduce(4, cap_k_0 + cap_k_1) + v_rc = T.axis.reduce(32, rc_0 * 32 + rc_1) + 
T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc], weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co]) + T.writes(conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = T.float32(0) - conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] + PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc] * weight[rh, rw, cap_k, cap_j, rc, co] + conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = T.float32(0) + conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc] * weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 1, 2, 4, 1, 16): with T.block("conv2d_capsule_nhwijc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(8, i1_0 * 4 + i1_1 + ax1) - v2 = T.axis.spatial(8, i2_1 * 2 + ax2) + v1 = T.axis.spatial(8, h_0 * 4 + h_1 + ax1) + v2 = T.axis.spatial(8, w_1 * 2 + ax2) v3 = T.axis.spatial(4, ax3) - v4 = T.axis.spatial(4, i4_1 + ax4) - v5 = T.axis.spatial(32, i5_1 * 16 + ax5) + v4 = T.axis.spatial(4, cap_j_1 + ax4) + v5 = T.axis.spatial(32, co_1 * 16 + ax5) T.reads(conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5]) T.writes(conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5]) conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5] = conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5] @T.prim_func def cap_1(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 18, 18, 4, 4, 32], dtype="float32") - conv2d_capsule_nhwijc_global = T.alloc_buffer([1, 8, 8, 4, 4, 32], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0 in T.grid(1, 2, 1, 1, 1, 1): - for i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 in T.grid(1, 4, 4, 1, 4, 2): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 18, 18, 4, 4, 32)) + conv2d_capsule_nhwijc_global = T.alloc_buffer((1, 8, 8, 4, 4, 32)) + for n_0, h_0, w_0, cap_i_0, cap_j_0, co_0 in T.grid(1, 2, 1, 1, 1, 1): + for n_1, h_1, w_1, cap_i_1, cap_j_1, co_1 in T.grid(1, 4, 4, 1, 4, 2): for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 3, 5, 4, 4, 32): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(18, i1_0 * 8 + i1_1 * 2 + ax1) - i2 = T.axis.spatial(18, i2_1 * 4 + ax2) - i3, i4, i5 = T.axis.remap("SSS", [ax3, ax4, ax5]) - T.reads(inputs[i0, i1 - 1, i2 - 1, i3, i4, i5]) - T.writes(PadInput[i0, i1, i2, i3, i4, i5]) - PadInput[i0, i1, i2, i3, i4, i5] = T.if_then_else(1 <= i1 and i1 < 17 and 1 <= i2 and i2 < 17, inputs[i0, i1 - 1, i2 - 1, i3, i4, i5], T.float32(0), dtype="float32") - for i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 
16): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(18, h_0 * 8 + h_1 * 2 + ax1) + v_i2 = T.axis.spatial(18, w_1 * 4 + ax2) + v_i3, v_i4, v_i5 = T.axis.remap("SSS", [ax3, ax4, ax5]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3, v_i4, v_i5]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3, v_i4, v_i5]) + PadInput[v_i0, v_i1, v_i2, v_i3, v_i4, v_i5] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3, v_i4, v_i5], T.float32(0)) + for rh_0, rw_0, cap_k_0, rc_0, n_2, h_2, w_2, cap_i_2, cap_j_2, co_2, rh_1, rw_1, cap_k_1, rc_1, n_3, h_3, w_3, cap_i_3, cap_j_3, co_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16): with T.block("conv2d_capsule_nhwijc"): - n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1) - h = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3) - w = T.axis.spatial(8, i2_0 * 8 + i2_1 * 2 + i2_2 + i2_3) - cap_i = T.axis.spatial(4, i3_0 * 4 + i3_1 * 4 + i3_2 * 4 + i3_3) - cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1 + i4_2 + i4_3) - co = T.axis.spatial(32, i5_0 * 32 + i5_1 * 16 + i5_2 * 16 + i5_3) - rh = T.axis.reduce(3, i6_0 * 3 + i6_1) - rw = T.axis.reduce(3, i7_1 + i7_0) - cap_k = T.axis.reduce(4, i8_0 + i8_1) - rc = T.axis.reduce(32, i9_0 * 32 + i9_1) - T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co]) - T.writes(conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(8, h_0 * 4 + h_1 + h_2 + h_3) + v_w = T.axis.spatial(8, w_0 * 8 + w_1 * 2 + w_2 + w_3) + v_cap_i = T.axis.spatial(4, cap_i_0 * 4 + cap_i_1 * 4 + cap_i_2 * 4 + cap_i_3) + v_cap_j = T.axis.spatial(4, cap_j_0 * 4 + cap_j_1 + cap_j_2 + cap_j_3) + v_co = T.axis.spatial(32, co_0 * 32 + co_1 * 16 + co_2 * 16 + co_3) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1) + v_rw = T.axis.reduce(3, rw_0 + rw_1) + v_cap_k = T.axis.reduce(4, cap_k_0 + cap_k_1) + v_rc = T.axis.reduce(32, rc_0 * 32 + rc_1) + T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc], weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co]) + T.writes(conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = T.float32(0) - conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] + PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc] * weight[rh, rw, cap_k, cap_j, rc, co] + conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = T.float32(0) + conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc] * weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 4, 8, 4, 4, 32): with T.block("conv2d_capsule_nhwijc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(8, i1_0 * 4 + ax1) + v1 = T.axis.spatial(8, h_0 * 4 + ax1) v2, v3, v4, v5 = T.axis.remap("SSSS", [ax2, ax3, ax4, ax5]) T.reads(conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5]) T.writes(conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5]) conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5] = conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5] @T.prim_func def cap_2(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: 
T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 18, 18, 4, 4, 32], dtype="float32") + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 18, 18, 4, 4, 32)) for i0, i1, i2, i3, i4, i5 in T.grid(1, 18, 18, 4, 4, 32): with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) - T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1, i4_1, i5_1]) - T.writes(PadInput[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1]) - PadInput[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1, i4_1, i5_1], T.float32(0), dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_1_1, i5_1_1, i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 4, 4, 1, 4, 2, 1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16): + v_i0, v_i1, v_i2, v_i3, v_i4, v_i5 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3, v_i4, v_i5]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3, v_i4, v_i5]) + PadInput[v_i0, v_i1, v_i2, v_i3, v_i4, v_i5] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3, v_i4, v_i5], T.float32(0)) + for n_0, h_0, w_0, cap_i_0, cap_j_0, co_0, n_1, h_1, w_1, cap_i_1, cap_j_1, co_1, rh_0, rw_0, cap_k_0, rc_0, n_2, h_2, w_2, cap_i_2, cap_j_2, co_2, rh_1, rw_1, cap_k_1, rc_1, n_3, h_3, w_3, cap_i_3, cap_j_3, co_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 4, 4, 1, 4, 2, 1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16): with T.block("conv2d_capsule_nhwijc"): - n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1_1) - h = T.axis.spatial(8, i1_0 * 4 + i1_1_1 + i1_2 + i1_3) - w = T.axis.spatial(8, i2_0 * 8 + i2_1_1 * 2 + i2_2 + i2_3) - cap_i = T.axis.spatial(4, i3_0 * 4 + i3_1_1 * 4 + i3_2 * 4 + i3_3) - cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1_1 + i4_2 + i4_3) - co = T.axis.spatial(32, i5_0 * 32 + i5_1_1 * 16 + i5_2 * 16 + i5_3) - rh = T.axis.reduce(3, i6_0 * 3 + i6_1) - rw = T.axis.reduce(3, i7_1 + i7_0) - cap_k = T.axis.reduce(4, i8_0 + i8_1) - rc = T.axis.reduce(32, i9_0 * 32 + i9_1) - T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co]) - T.writes(conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(8, h_0 * 4 + h_1 + h_2 + h_3) + v_w = T.axis.spatial(8, w_0 * 8 + w_1 * 2 + w_2 + w_3) + v_cap_i = T.axis.spatial(4, cap_i_0 * 4 + cap_i_1 * 4 + cap_i_2 * 4 + cap_i_3) + v_cap_j = T.axis.spatial(4, cap_j_0 * 4 + cap_j_1 + cap_j_2 + cap_j_3) + v_co = T.axis.spatial(32, co_0 * 32 + co_1 * 16 + co_2 * 16 + co_3) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1) + v_rw = T.axis.reduce(3, rw_0 + rw_1) + v_cap_k = T.axis.reduce(4, cap_k_0 + cap_k_1) + v_rc = T.axis.reduce(32, rc_0 * 32 + rc_1) + 
T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc], weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co]) + T.writes(conv2d_capsule_nhwijc[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co] = T.float32(0) - conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co] = conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co] + PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc] * weight[rh, rw, cap_k, cap_j, rc, co] + conv2d_capsule_nhwijc[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = T.float32(0) + conv2d_capsule_nhwijc[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = conv2d_capsule_nhwijc[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc] * weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co] # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 1, 1, 1]), @@ -738,77 +715,73 @@ def test_cpu_dep(): # fmt: off @T.prim_func def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 114, 114, 32], dtype="float32") - depth_conv2d_nhwc_global = T.alloc_buffer([1, 112, 112, 32], dtype="float32") + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 114, 114, 32)) + depth_conv2d_nhwc_global = T.alloc_buffer((1, 112, 112, 32)) for i0, i1, i2, i3 in T.grid(1, 114, 114, 32): with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(placeholder[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) - T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 113 and 1 <= i2_1 and i2_1 < 113, placeholder[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1 in T.grid(1, 1, 1, 1, 1, 4, 4, 8): - for i4_0, i5_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(placeholder[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 113 and 1 <= v_i2 and v_i2 < 113, placeholder[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) + for n_0, h_0, w_0, c_0, n_1, h_1, w_1, c_1 in T.grid(1, 1, 1, 1, 1, 4, 4, 8): + for rh_0, rw_0, n_2, h_2, w_2, c_2, rh_1, rw_1, n_3, h_3, w_3, c_3 in T.grid(1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2): with T.block("depth_conv2d_nhwc"): - n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1_1) - h = T.axis.spatial(112, i1_0 * 112 + i1_1_1 * 28 + i1_2 * 14 + i1_3) - w = T.axis.spatial(112, i2_0 * 112 + i2_1_1 * 28 + i2_2 * 4 + i2_3) - c = T.axis.spatial(32, i3_0 * 32 + i3_1_1 * 4 + i3_2 * 2 + i3_3) - rh = T.axis.reduce(3, i4_0 * 3 + i4_1) - rw = T.axis.reduce(3, i5_0 * 3 + i5_1) - T.reads(PadInput[n, h + rh, w + rw, c], placeholder_1[0, rh, rw, c]) - T.writes(depth_conv2d_nhwc_global[n, h, w, c]) - 
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(112, h_0 * 112 + h_1 * 28 + h_2 * 14 + h_3) + v_w = T.axis.spatial(112, w_0 * 112 + w_1 * 28 + w_2 * 4 + w_3) + v_c = T.axis.spatial(32, c_0 * 32 + c_1 * 4 + c_2 * 2 + c_3) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1) + v_rw = T.axis.reduce(3, rw_0 * 3 + rw_1) + T.reads(PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c], placeholder_1[0, v_rh, v_rw, v_c]) + T.writes(depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - depth_conv2d_nhwc_global[n, h, w, c] = T.float32(0) - depth_conv2d_nhwc_global[n, h, w, c] = depth_conv2d_nhwc_global[n, h, w, c] + PadInput[n, h + rh, w + rw, c] * placeholder_1[0, rh, rw, c] + depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] = T.float32(0) + depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] = depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] + PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c] * placeholder_1[0, v_rh, v_rw, v_c] for ax0, ax1, ax2, ax3 in T.grid(1, 28, 28, 4): with T.block("depth_conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(112, i1_1_1 * 28 + ax1) - v2 = T.axis.spatial(112, i2_1_1 * 28 + ax2) - v3 = T.axis.spatial(32, i3_1_1 * 4 + ax3) + v1 = T.axis.spatial(112, h_1 * 28 + ax1) + v2 = T.axis.spatial(112, w_1 * 28 + ax2) + v3 = T.axis.spatial(32, c_1 * 4 + ax3) T.reads(depth_conv2d_nhwc_global[v0, v1, v2, v3]) T.writes(depth_conv2d_nhwc[v0, v1, v2, v3]) depth_conv2d_nhwc[v0, v1, v2, v3] = depth_conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def dep_1(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 114, 114, 32], dtype="float32") - depth_conv2d_nhwc_global = T.alloc_buffer([1, 112, 112, 32], dtype="float32") + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 114, 114, 32)) + depth_conv2d_nhwc_global = T.alloc_buffer((1, 112, 112, 32)) for i0, i1, i2, i3 in T.grid(1, 114, 114, 32): with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(placeholder[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) - T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 113 and 1 <= i2_1 and i2_1 < 113, placeholder[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32") - for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 1, 1): - for i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_0, i5_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 4, 8, 1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(placeholder[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 113 and 1 <= v_i2 and v_i2 < 113, placeholder[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) + for n_0, h_0, w_0, c_0 in T.grid(1, 1, 1, 1): + for n_1, h_1, w_1, c_1, rh_0, rw_0, n_2, h_2, 
w_2, c_2, rh_1, rw_1, n_3, h_3, w_3, c_3 in T.grid(1, 4, 4, 8, 1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2): with T.block("depth_conv2d_nhwc"): - n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1_1) - h = T.axis.spatial(112, i1_0 * 112 + i1_1_1 * 28 + i1_2 * 14 + i1_3) - w = T.axis.spatial(112, i2_0 * 112 + i2_1_1 * 28 + i2_2 * 4 + i2_3) - c = T.axis.spatial(32, i3_0 * 32 + i3_1_1 * 4 + i3_2 * 2 + i3_3) - rh = T.axis.reduce(3, i4_0 * 3 + i4_1) - rw = T.axis.reduce(3, i5_0 * 3 + i5_1) - T.reads(PadInput[n, h + rh, w + rw, c], placeholder_1[0, rh, rw, c]) - T.writes(depth_conv2d_nhwc_global[n, h, w, c]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(112, h_0 * 112 + h_1 * 28 + h_2 * 14 + h_3) + v_w = T.axis.spatial(112, w_0 * 112 + w_1 * 28 + w_2 * 4 + w_3) + v_c = T.axis.spatial(32, c_0 * 32 + c_1 * 4 + c_2 * 2 + c_3) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1) + v_rw = T.axis.reduce(3, rw_0 * 3 + rw_1) + T.reads(PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c], placeholder_1[0, v_rh, v_rw, v_c]) + T.writes(depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - depth_conv2d_nhwc_global[n, h, w, c] = T.float32(0) - depth_conv2d_nhwc_global[n, h, w, c] = depth_conv2d_nhwc_global[n, h, w, c] + PadInput[n, h + rh, w + rw, c] * placeholder_1[0, rh, rw, c] + depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] = T.float32(0) + depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] = depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] + PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c] * placeholder_1[0, v_rh, v_rw, v_c] for ax0, ax1, ax2, ax3 in T.grid(1, 112, 112, 32): with T.block("depth_conv2d_nhwc_global"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -817,37 +790,35 @@ def dep_1(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. 
depth_conv2d_nhwc[v0, v1, v2, v3] = depth_conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def dep_2(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 114, 114, 32], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1 in T.grid(1, 1, 1, 1, 1, 4): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 114, 114, 32)) + for n_0, h_0, w_0, c_0, n_1, h_1 in T.grid(1, 1, 1, 1, 1, 4): for ax0, ax1, ax2, ax3 in T.grid(1, 30, 114, 32): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(114, i1_1 * 28 + ax1) - i2, i3 = T.axis.remap("SS", [ax2, ax3]) - T.reads(placeholder[i0, i1 - 1, i2 - 1, i3]) - T.writes(PadInput[i0, i1, i2, i3]) - PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 113 and 1 <= i2 and i2 < 113, placeholder[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") - for i2_1, i3_1, i4_0, i5_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i0_3, i1_3, i2_3, i3_3 in T.grid(4, 8, 1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(114, h_1 * 28 + ax1) + v_i2, v_i3 = T.axis.remap("SS", [ax2, ax3]) + T.reads(placeholder[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 113 and 1 <= v_i2 and v_i2 < 113, placeholder[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) + for w_1, c_1, rh_0, rw_0, n_2, h_2, w_2, c_2, rh_1, rw_1, n_3, h_3, w_3, c_3 in T.grid(4, 8, 1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2): with T.block("depth_conv2d_nhwc"): - n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1) - h = T.axis.spatial(112, i1_0 * 112 + i1_1 * 28 + i1_2 * 14 + i1_3) - w = T.axis.spatial(112, i2_0 * 112 + i2_1 * 28 + i2_2 * 4 + i2_3) - c = T.axis.spatial(32, i3_0 * 32 + i3_1 * 4 + i3_2 * 2 + i3_3) - rh = T.axis.reduce(3, i4_0 * 3 + i4_1) - rw = T.axis.reduce(3, i5_0 * 3 + i5_1) - T.reads(PadInput[n, h + rh, w + rw, c], placeholder_1[0, rh, rw, c]) - T.writes(depth_conv2d_nhwc[n, h, w, c]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(112, h_0 * 112 + h_1 * 28 + h_2 * 14 + h_3) + v_w = T.axis.spatial(112, w_0 * 112 + w_1 * 28 + w_2 * 4 + w_3) + v_c = T.axis.spatial(32, c_0 * 32 + c_1 * 4 + c_2 * 2 + c_3) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1) + v_rw = T.axis.reduce(3, rw_0 * 3 + rw_1) + T.reads(PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c], placeholder_1[0, v_rh, v_rw, v_c]) + T.writes(depth_conv2d_nhwc[v_n, v_h, v_w, v_c]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - depth_conv2d_nhwc[n, h, w, c] = T.float32(0) - depth_conv2d_nhwc[n, h, w, c] = depth_conv2d_nhwc[n, h, w, c] + PadInput[n, h + rh, w + rw, c] * placeholder_1[0, rh, rw, c] + depth_conv2d_nhwc[v_n, v_h, v_w, v_c] = T.float32(0) + depth_conv2d_nhwc[v_n, v_h, v_w, v_c] = depth_conv2d_nhwc[v_n, v_h, v_w, v_c] + PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c] * placeholder_1[0, v_rh, v_rw, v_c] # fmt: on decision_0 = [ 
("SamplePerfectTile", [1, 1, 1, 1]), @@ -893,131 +864,124 @@ def test_cpu_dil(): # fmt: off @T.prim_func def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") - conv2d_nhwc_global = T.alloc_buffer([1, 109, 109, 64], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1 in T.grid(1, 109, 1, 4, 1, 1, 1, 2): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 230, 230, 3)) + conv2d_nhwc_global = T.alloc_buffer((1, 109, 109, 64)) + for n_0, h_0, w_0, co_0, n_1, h_1, w_1, co_1 in T.grid(1, 109, 1, 4, 1, 1, 1, 2): for ax0, ax1, ax2, ax3 in T.grid(1, 13, 229, 3): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(230, i1_0 * 2 + ax1) - i2 = T.axis.spatial(230, ax2) - i3 = T.axis.spatial(3, ax3) - T.reads(inputs[i0, i1 - 3, i2 - 3, i3]) - T.writes(PadInput[i0, i1, i2, i3]) - PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32") - for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(7, 1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(230, h_0 * 2 + ax1) + v_i2 = T.axis.spatial(230, ax2) + v_i3 = T.axis.spatial(3, ax3) + T.reads(inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) + for rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(7, 1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - h = T.axis.spatial(109, i1_2 + i1_3 + i1_0 + i1_1) - w = T.axis.spatial(109, i2_3 + i2_0 * 109 + i2_1 * 109 + i2_2) - co = T.axis.spatial(64, i3_0 * 16 + i3_1 * 8 + i3_2 + i3_3) - rh = T.axis.reduce(7, i4_1 + i4_0) - rw = T.axis.reduce(7, i5_0 * 7 + i5_1) - rc = T.axis.reduce(3, i6_0 * 3 + i6_1) - T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co]) - T.writes(conv2d_nhwc_global[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(109, h_0 + h_1 + h_2 + h_3) + v_w = T.axis.spatial(109, w_0 * 109 + w_1 * 109 + w_2 + w_3) + v_co = T.axis.spatial(64, co_0 * 16 + co_1 * 8 + co_2 + co_3) + v_rh = T.axis.reduce(7, rh_0 + rh_1) + v_rw = T.axis.reduce(7, rw_0 * 7 + rw_1) + v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) + T.reads(PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_nhwc_global[n, h, w, co] = T.float32(0) - conv2d_nhwc_global[n, h, w, co] = conv2d_nhwc_global[n, h, w, co] + PadInput[n, h 
* 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc] * weight[rh, rw, rc, co] + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 1, 109, 8): with T.block("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(109, i1_0 + ax1) + v1 = T.axis.spatial(109, h_0 + ax1) v2 = T.axis.spatial(109, ax2) - v3 = T.axis.spatial(64, i3_0 * 16 + i3_1 * 8 + ax3) + v3 = T.axis.spatial(64, co_0 * 16 + co_1 * 8 + ax3) T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) T.writes(conv2d_nhwc[v0, v1, v2, v3]) conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def dil_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") - conv2d_nhwc_global = T.alloc_buffer([1, 109, 109, 64], dtype="float32") - for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 109, 1, 4): - for i0_1, i1_1, i2_1, i3_1, i4_0 in T.grid(1, 1, 1, 2, 7): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 230, 230, 3)) + conv2d_nhwc_global = T.alloc_buffer((1, 109, 109, 64)) + for n_0, h_0, w_0, co_0 in T.grid(1, 109, 1, 4): + for n_1, h_1, w_1, co_1, rh_0 in T.grid(1, 1, 1, 2, 7): for ax0, ax1, ax2, ax3 in T.grid(1, 1, 229, 3): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(230, i1_0 * 2 + i4_0 * 2 + ax1) - i2 = T.axis.spatial(230, ax2) - i3 = T.axis.spatial(3, ax3) - T.reads(inputs[i0, i1 - 3, i2 - 3, i3]) - T.writes(PadInput[i0, i1, i2, i3]) - PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32") - for i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(230, h_0 * 2 + rh_0 * 2 + ax1) + v_i2 = T.axis.spatial(230, ax2) + v_i3 = T.axis.spatial(3, ax3) + T.reads(inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) + for rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - h = T.axis.spatial(109, i1_2 + i1_3 + i1_0 + i1_1) - w = T.axis.spatial(109, i2_3 + i2_0 * 109 + i2_1 * 109 + i2_2) - co = T.axis.spatial(64, i3_0 * 16 + i3_1 * 8 + i3_2 + i3_3) - rh = T.axis.reduce(7, i4_1 + i4_0) - rw = T.axis.reduce(7, i5_0 * 7 + i5_1) - rc = T.axis.reduce(3, i6_0 * 3 + i6_1) - T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co]) - T.writes(conv2d_nhwc_global[n, h, w, co]) - 
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(109, h_0 + h_1 + h_2 + h_3) + v_w = T.axis.spatial(109, w_0 * 109 + w_1 * 109 + w_2 + w_3) + v_co = T.axis.spatial(64, co_0 * 16 + co_1 * 8 + co_2 + co_3) + v_rh = T.axis.reduce(7, rh_0 + rh_1) + v_rw = T.axis.reduce(7, rw_0 * 7 + rw_1) + v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) + T.reads(PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_nhwc_global[n, h, w, co] = T.float32(0) - conv2d_nhwc_global[n, h, w, co] = conv2d_nhwc_global[n, h, w, co] + PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc] * weight[rh, rw, rc, co] + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 1, 109, 16): with T.block("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(109, i1_0 + ax1) + v1 = T.axis.spatial(109, h_0 + ax1) v2 = T.axis.spatial(109, ax2) - v3 = T.axis.spatial(64, i3_0 * 16 + ax3) + v3 = T.axis.spatial(64, co_0 * 16 + ax3) T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) T.writes(conv2d_nhwc[v0, v1, v2, v3]) conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def dil_2(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") - for i0_0, i1_0 in T.grid(1, 109): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 230, 230, 3)) + for n_0, h_0 in T.grid(1, 109): for ax0, ax1, ax2, ax3 in T.grid(1, 13, 229, 3): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(230, i1_0 * 2 + ax1) - i2 = T.axis.spatial(230, ax2) - i3 = T.axis.spatial(3, ax3) - T.reads(inputs[i0, i1 - 3, i2 - 3, i3]) - T.writes(PadInput[i0, i1, i2, i3]) - PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32") - for i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 1, 1, 1, 2, 7, 1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(230, h_0 * 2 + ax1) + v_i2 = T.axis.spatial(230, ax2) + v_i3 = T.axis.spatial(3, ax3) + T.reads(inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) + for w_0, co_0, n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 4, 1, 1, 1, 2, 7, 1, 1, 1, 
1, 109, 8, 1, 7, 3, 1, 1, 1, 1): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - h = T.axis.spatial(109, i1_2 + i1_3 + i1_0 + i1_1) - w = T.axis.spatial(109, i2_3 + i2_0 * 109 + i2_1 * 109 + i2_2) - co = T.axis.spatial(64, i3_0 * 16 + i3_1 * 8 + i3_2 + i3_3) - rh = T.axis.reduce(7, i4_1 + i4_0) - rw = T.axis.reduce(7, i5_0 * 7 + i5_1) - rc = T.axis.reduce(3, i6_0 * 3 + i6_1) - T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co]) - T.writes(conv2d_nhwc[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(109, h_0 + h_1 + h_2 + h_3) + v_w = T.axis.spatial(109, w_0 * 109 + w_1 * 109 + w_2 + w_3) + v_co = T.axis.spatial(64, co_0 * 16 + co_1 * 8 + co_2 + co_3) + v_rh = T.axis.reduce(7, rh_0 + rh_1) + v_rw = T.axis.reduce(7, rw_0 * 7 + rw_1) + v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) + T.reads(PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_nhwc[n, h, w, co] = T.float32(0) - conv2d_nhwc[n, h, w, co] = conv2d_nhwc[n, h, w, co] + PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc] * weight[rh, rw, rc, co] - + conv2d_nhwc[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc[v_n, v_h, v_w, v_co] = conv2d_nhwc[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 1, 1, 1]), @@ -1066,87 +1030,81 @@ def test_cpu_gmm(): # fmt: off @T.prim_func def gmm_0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - Z_global = T.alloc_buffer([1, 128, 128], dtype="float32") - for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1 in T.grid(1, 4, 2, 1, 1, 8): - for i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(128, 1, 16, 1, 1, 1, 2, 8): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + Z_global = T.alloc_buffer((1, 128, 128)) + for b_0, i_0, j_0, b_1, i_1, j_1 in T.grid(1, 4, 2, 1, 1, 8): + for k_0, b_2, i_2, j_2, k_1, b_3, i_3, j_3 in T.grid(128, 1, 16, 1, 1, 1, 2, 8): with T.block("Z"): - b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3) - i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3) - j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3) - k = T.axis.reduce(128, i3_1 + i3_0) - T.reads(X[b, i, k], Y[b, k, j]) - T.writes(Z_global[b, i, j]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) + v_i = T.axis.spatial(128, i_0 * 32 + i_1 * 32 + i_2 * 2 + i_3) + v_j = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + j_2 * 8 + j_3) + v_k = T.axis.reduce(128, k_0 + k_1) + T.reads(X[v_b, v_i, v_k], Y[v_b, v_k, v_j]) + T.writes(Z_global[v_b, v_i, v_j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - Z_global[b, i, j] = T.float32(0) - Z_global[b, i, j] = Z_global[b, i, j] + X[b, 
i, k] * Y[b, k, j] + Z_global[v_b, v_i, v_j] = T.float32(0) + Z_global[v_b, v_i, v_j] = Z_global[v_b, v_i, v_j] + X[v_b, v_i, v_k] * Y[v_b, v_k, v_j] for ax0, ax1, ax2 in T.grid(1, 32, 8): with T.block("Z_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(128, i1_0 * 32 + ax1) - v2 = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + ax2) + v1 = T.axis.spatial(128, i_0 * 32 + ax1) + v2 = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + ax2) T.reads(Z_global[v0, v1, v2]) T.writes(Z[v0, v1, v2]) Z[v0, v1, v2] = Z_global[v0, v1, v2] @T.prim_func def gmm_1(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - Z_global = T.alloc_buffer([1, 128, 128], dtype="float32") - for i0_0, i1_0, i2_0 in T.grid(1, 4, 2): - for i0_1, i1_1, i2_1, i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + Z_global = T.alloc_buffer((1, 128, 128)) + for b_0, i_0, j_0 in T.grid(1, 4, 2): + for b_1, i_1, j_1, k_0, b_2, i_2, j_2, k_1, b_3, i_3, j_3 in T.grid(1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8): with T.block("Z"): - b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3) - i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3) - j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3) - k = T.axis.reduce(128, i3_1 + i3_0) - T.reads(X[b, i, k], Y[b, k, j]) - T.writes(Z_global[b, i, j]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) + v_i = T.axis.spatial(128, i_0 * 32 + i_1 * 32 + i_2 * 2 + i_3) + v_j = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + j_2 * 8 + j_3) + v_k = T.axis.reduce(128, k_0 + k_1) + T.reads(X[v_b, v_i, v_k], Y[v_b, v_k, v_j]) + T.writes(Z_global[v_b, v_i, v_j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - Z_global[b, i, j] = T.float32(0) - Z_global[b, i, j] = Z_global[b, i, j] + X[b, i, k] * Y[b, k, j] + Z_global[v_b, v_i, v_j] = T.float32(0) + Z_global[v_b, v_i, v_j] = Z_global[v_b, v_i, v_j] + X[v_b, v_i, v_k] * Y[v_b, v_k, v_j] for ax0, ax1, ax2 in T.grid(1, 32, 64): with T.block("Z_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(128, i1_0 * 32 + ax1) - v2 = T.axis.spatial(128, i2_0 * 64 + ax2) + v1 = T.axis.spatial(128, i_0 * 32 + ax1) + v2 = T.axis.spatial(128, j_0 * 64 + ax2) T.reads(Z_global[v0, v1, v2]) T.writes(Z[v0, v1, v2]) Z[v0, v1, v2] = Z_global[v0, v1, v2] @T.prim_func def gmm_2(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(1, 4, 2, 1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, 
"meta_schedule.vectorize": 64}) + for b_0, i_0, j_0, b_1, i_1, j_1, k_0, b_2, i_2, j_2, k_1, b_3, i_3, j_3 in T.grid(1, 4, 2, 1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8): with T.block("Z"): - b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3) - i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3) - j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3) - k = T.axis.reduce(128, i3_1 + i3_0) - T.reads(X[b, i, k], Y[b, k, j]) - T.writes(Z[b, i, j]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) + v_i = T.axis.spatial(128, i_0 * 32 + i_1 * 32 + i_2 * 2 + i_3) + v_j = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + j_2 * 8 + j_3) + v_k = T.axis.reduce(128, k_0 + k_1) + T.reads(X[v_b, v_i, v_k], Y[v_b, v_k, v_j]) + T.writes(Z[v_b, v_i, v_j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - Z[b, i, j] = T.float32(0) - Z[b, i, j] = Z[b, i, j] + X[b, i, k] * Y[b, k, j] + Z[v_b, v_i, v_j] = T.float32(0) + Z[v_b, v_i, v_j] = Z[v_b, v_i, v_j] + X[v_b, v_i, v_k] * Y[v_b, v_k, v_j] # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 1, 1, 1]), @@ -1183,127 +1141,121 @@ def test_cpu_grp(): # fmt: off @T.prim_func def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float32") - conv2d_nhwc_global = T.alloc_buffer([1, 28, 28, 128], dtype="float32") - for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 7, 1, 2): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 58, 58, 64)) + conv2d_nhwc_global = T.alloc_buffer((1, 28, 28, 128)) + for n_0, h_0, w_0, co_0 in T.grid(1, 7, 1, 2): for ax0, ax1, ax2, ax3 in T.grid(1, 9, 57, 32): with T.block("PadInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(58, i1_0 * 8 + ax1) - i2 = T.axis.spatial(58, ax2) - i3 = T.axis.spatial(64, i3_0 * 32 + ax3) - T.reads(inputs[i0, i1 - 1, i2 - 1, i3]) - T.writes(PadInput[i0, i1, i2, i3]) - PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, inputs[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") - for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 4, 1, 1): - for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 3, 8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(58, h_0 * 8 + ax1) + v_i2 = T.axis.spatial(58, ax2) + v_i3 = T.axis.spatial(64, co_0 * 32 + ax3) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 57 and 1 <= v_i2 and v_i2 < 57, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) + for n_1, h_1, w_1, co_1 in T.grid(1, 4, 1, 1): + for rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 3, 8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - h = T.axis.spatial(28, i1_0 * 4 + i1_1 + i1_2 + i1_3) - w = T.axis.spatial(28, i2_0 * 
28 + i2_1 * 28 + i2_2 * 7 + i2_3) - co = T.axis.spatial(128, i3_0 * 64 + i3_1 * 64 + i3_2 * 16 + i3_3) - rh = T.axis.reduce(3, i4_0 * 3 + i4_1) - rw = T.axis.reduce(3, i5_0 + i5_1) - rc = T.axis.reduce(16, i6_0 * 2 + i6_1) - T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc], weight[rh, rw, rc, co]) - T.writes(conv2d_nhwc_global[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(28, h_0 * 4 + h_1 + h_2 + h_3) + v_w = T.axis.spatial(28, w_0 * 28 + w_1 * 28 + w_2 * 7 + w_3) + v_co = T.axis.spatial(128, co_0 * 64 + co_1 * 64 + co_2 * 16 + co_3) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1) + v_rw = T.axis.reduce(3, rw_0 + rw_1) + v_rc = T.axis.reduce(16, rc_0 * 2 + rc_1) + T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_nhwc_global[n, h, w, co] = T.float32(0) - conv2d_nhwc_global[n, h, w, co] = conv2d_nhwc_global[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc] * weight[rh, rw, rc, co] + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 1, 28, 64): with T.block("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(28, i1_0 * 4 + i1_1 + ax1) + v1 = T.axis.spatial(28, h_0 * 4 + h_1 + ax1) v2 = T.axis.spatial(28, ax2) - v3 = T.axis.spatial(128, i3_0 * 64 + ax3) + v3 = T.axis.spatial(128, co_0 * 64 + ax3) T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) T.writes(conv2d_nhwc[v0, v1, v2, v3]) conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def grp_1(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float32") - conv2d_nhwc_global = T.alloc_buffer([1, 28, 28, 128], dtype="float32") + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 58, 58, 64)) + conv2d_nhwc_global = T.alloc_buffer((1, 28, 28, 128)) for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) - T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32") - for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 7, 1, 2): - for i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 1, 1, 1, 3, 8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) 
+ T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 57 and 1 <= v_i2 and v_i2 < 57, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) + for n_0, h_0, w_0, co_0 in T.grid(1, 7, 1, 2): + for n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 4, 1, 1, 1, 3, 8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1_1 + i0_2) - h = T.axis.spatial(28, i1_0 * 4 + i1_1_1 + i1_2 + i1_3) - w = T.axis.spatial(28, i2_0 * 28 + i2_1_1 * 28 + i2_2 * 7 + i2_3) - co = T.axis.spatial(128, i3_0 * 64 + i3_1_1 * 64 + i3_2 * 16 + i3_3) - rh = T.axis.reduce(3, i4_0 * 3 + i4_1) - rw = T.axis.reduce(3, i5_0 + i5_1) - rc = T.axis.reduce(16, i6_0 * 2 + i6_1) - T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc], weight[rh, rw, rc, co]) - T.writes(conv2d_nhwc_global[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(28, h_0 * 4 + h_1 + h_2 + h_3) + v_w = T.axis.spatial(28, w_0 * 28 + w_1 * 28 + w_2 * 7 + w_3) + v_co = T.axis.spatial(128, co_0 * 64 + co_1 * 64 + co_2 * 16 + co_3) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1) + v_rw = T.axis.reduce(3, rw_0 + rw_1) + v_rc = T.axis.reduce(16, rc_0 * 2 + rc_1) + T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_nhwc_global[n, h, w, co] = T.float32(0) - conv2d_nhwc_global[n, h, w, co] = conv2d_nhwc_global[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc] * weight[rh, rw, rc, co] + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 4, 28, 64): with T.block("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(28, i1_0 * 4 + ax1) + v1 = T.axis.spatial(28, h_0 * 4 + ax1) v2 = T.axis.spatial(28, ax2) - v3 = T.axis.spatial(128, i3_0 * 64 + ax3) + v3 = T.axis.spatial(128, co_0 * 64 + ax3) T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) T.writes(conv2d_nhwc[v0, v1, v2, v3]) conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def grp_2(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0 in T.grid(1, 7, 1, 2, 1, 4, 1, 1, 1, 3): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 58, 58, 64)) + for n_0, h_0, w_0, co_0, n_1, h_1, w_1, co_1, rh_0, rw_0 in T.grid(1, 7, 1, 2, 1, 4, 1, 1, 1, 3): for ax0, ax1, ax2, ax3 in T.grid(1, 3, 55, 32): with T.block("PadInput"): - i0 = T.axis.spatial(1, 
ax0) - i1 = T.axis.spatial(58, i1_0 * 8 + i1_1 * 2 + ax1) - i2 = T.axis.spatial(58, i5_0 + ax2) - i3 = T.axis.spatial(64, i3_0 * 32 + ax3) - T.reads(inputs[i0, i1 - 1, i2 - 1, i3]) - T.writes(PadInput[i0, i1, i2, i3]) - PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, inputs[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") - for i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(58, h_0 * 8 + h_1 * 2 + ax1) + v_i2 = T.axis.spatial(58, rw_0 + ax2) + v_i3 = T.axis.spatial(64, co_0 * 32 + ax3) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 57 and 1 <= v_i2 and v_i2 < 57, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) + for rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - h = T.axis.spatial(28, i1_0 * 4 + i1_1 + i1_2 + i1_3) - w = T.axis.spatial(28, i2_0 * 28 + i2_1 * 28 + i2_2 * 7 + i2_3) - co = T.axis.spatial(128, i3_0 * 64 + i3_1 * 64 + i3_2 * 16 + i3_3) - rh = T.axis.reduce(3, i4_0 * 3 + i4_1) - rw = T.axis.reduce(3, i5_0 + i5_1) - rc = T.axis.reduce(16, i6_0 * 2 + i6_1) - T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc], weight[rh, rw, rc, co]) - T.writes(conv2d_nhwc[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(28, h_0 * 4 + h_1 + h_2 + h_3) + v_w = T.axis.spatial(28, w_0 * 28 + w_1 * 28 + w_2 * 7 + w_3) + v_co = T.axis.spatial(128, co_0 * 64 + co_1 * 64 + co_2 * 16 + co_3) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1) + v_rw = T.axis.reduce(3, rw_0 + rw_1) + v_rc = T.axis.reduce(16, rc_0 * 2 + rc_1) + T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_nhwc[n, h, w, co] = T.float32(0) - conv2d_nhwc[n, h, w, co] = conv2d_nhwc[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc] * weight[rh, rw, rc, co] + conv2d_nhwc[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc[v_n, v_h, v_w, v_co] = conv2d_nhwc[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 1, 1, 1]), @@ -1352,113 +1304,107 @@ def test_cpu_t2d(): # fmt: off @T.prim_func def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 6, 6, 512], dtype="float32") - conv2d_transpose_nhwc_global = T.alloc_buffer([1, 8, 8, 256], dtype="float32") + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 6, 6, 512)) + 
conv2d_transpose_nhwc_global = T.alloc_buffer((1, 8, 8, 256)) for i0, i1, i2, i3 in T.grid(1, 6, 6, 512): with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) - T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 5 and 1 <= i2_1 and i2_1 < 5, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1 in T.grid(1, 1, 2, 8, 1, 4, 1, 4): - for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 5 and 1 <= v_i2 and v_i2 < 5, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) + for n_0, h_0, w_0, co_0, n_1, h_1, w_1, co_1 in T.grid(1, 1, 2, 8, 1, 4, 1, 4): + for rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8): with T.block("conv2d_transpose_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1_1 + i0_2) - h = T.axis.spatial(8, i1_0 * 8 + i1_1_1 * 2 + i1_2 * 2 + i1_3) - w = T.axis.spatial(8, i2_0 * 4 + i2_1_1 * 4 + i2_2 * 4 + i2_3) - co = T.axis.spatial(256, i3_0 * 32 + i3_1_1 * 8 + i3_2 * 8 + i3_3) - rh = T.axis.reduce(4, i4_0 * 2 + i4_1) - rw = T.axis.reduce(4, i5_0 * 2 + i5_1) - rc = T.axis.reduce(512, i6_0 * 8 + i6_1) - T.reads(PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], weight[3 - rh, 3 - rw, rc, co]) - T.writes(conv2d_transpose_nhwc_global[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(8, h_0 * 8 + h_1 * 2 + h_2 * 2 + h_3) + v_w = T.axis.spatial(8, w_0 * 4 + w_1 * 4 + w_2 * 4 + w_3) + v_co = T.axis.spatial(256, co_0 * 32 + co_1 * 8 + co_2 * 8 + co_3) + v_rh = T.axis.reduce(4, rh_0 * 2 + rh_1) + v_rw = T.axis.reduce(4, rw_0 * 2 + rw_1) + v_rc = T.axis.reduce(512, rc_0 * 8 + rc_1) + T.reads(PadInput[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], weight[3 - v_rh, 3 - v_rw, v_rc, v_co]) + T.writes(conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_transpose_nhwc_global[n, h, w, co] = T.float32(0) - conv2d_transpose_nhwc_global[n, h, w, co] = conv2d_transpose_nhwc_global[n, h, w, co] + T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], T.float32(0), dtype="float32") * weight[3 - rh, 3 - rw, rc, co] + conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] + T.if_then_else((v_h + v_rh) % 2 == 0 and (v_w + v_rw) % 2 == 0, PadInput[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], T.float32(0)) * weight[3 - v_rh, 3 - v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 2, 4, 8): with T.block("conv2d_transpose_nhwc_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(8, i1_1_1 * 2 + ax1) - v2 = T.axis.spatial(8, i2_0 * 4 + ax2) - v3 = T.axis.spatial(256, i3_0 * 32 + i3_1_1 * 8 + ax3) + v1 = T.axis.spatial(8, h_1 * 2 + ax1) + v2 = T.axis.spatial(8, w_0 * 4 + ax2) + v3 = T.axis.spatial(256, co_0 * 32 + co_1 * 8 + ax3) 
T.reads(conv2d_transpose_nhwc_global[v0, v1, v2, v3]) T.writes(conv2d_transpose_nhwc[v0, v1, v2, v3]) conv2d_transpose_nhwc[v0, v1, v2, v3] = conv2d_transpose_nhwc_global[v0, v1, v2, v3] @T.prim_func def t2d_1(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) - PadInput = T.alloc_buffer([1, 6, 6, 512], dtype="float32") - conv2d_transpose_nhwc_global = T.alloc_buffer([1, 8, 8, 256], dtype="float32") - for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 2, 8): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + PadInput = T.alloc_buffer((1, 6, 6, 512)) + conv2d_transpose_nhwc_global = T.alloc_buffer((1, 8, 8, 256)) + for n_0, h_0, w_0, co_0 in T.grid(1, 1, 2, 8): for ax0, ax1, ax2, ax3 in T.grid(1, 6, 4, 512): with T.block("PadInput"): - i0, i1 = T.axis.remap("SS", [ax0, ax1]) - i2 = T.axis.spatial(6, i2_0 * 2 + ax2) - i3 = T.axis.spatial(512, ax3) - T.reads(inputs[i0, i1 - 1, i2 - 1, i3]) - T.writes(PadInput[i0, i1, i2, i3]) - PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 5 and 1 <= i2 and i2 < 5, inputs[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") - for i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 1, 4, 2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8): + v_i0, v_i1 = T.axis.remap("SS", [ax0, ax1]) + v_i2 = T.axis.spatial(6, w_0 * 2 + ax2) + v_i3 = T.axis.spatial(512, ax3) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 5 and 1 <= v_i2 and v_i2 < 5, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) + for n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 4, 1, 4, 2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8): with T.block("conv2d_transpose_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - h = T.axis.spatial(8, i1_0 * 8 + i1_1 * 2 + i1_2 * 2 + i1_3) - w = T.axis.spatial(8, i2_0 * 4 + i2_1 * 4 + i2_2 * 4 + i2_3) - co = T.axis.spatial(256, i3_0 * 32 + i3_1 * 8 + i3_2 * 8 + i3_3) - rh = T.axis.reduce(4, i4_0 * 2 + i4_1) - rw = T.axis.reduce(4, i5_0 * 2 + i5_1) - rc = T.axis.reduce(512, i6_0 * 8 + i6_1) - T.reads(PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], weight[3 - rh, 3 - rw, rc, co]) - T.writes(conv2d_transpose_nhwc_global[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(8, h_0 * 8 + h_1 * 2 + h_2 * 2 + h_3) + v_w = T.axis.spatial(8, w_0 * 4 + w_1 * 4 + w_2 * 4 + w_3) + v_co = T.axis.spatial(256, co_0 * 32 + co_1 * 8 + co_2 * 8 + co_3) + v_rh = T.axis.reduce(4, rh_0 * 2 + rh_1) + v_rw = T.axis.reduce(4, rw_0 * 2 + rw_1) + v_rc = T.axis.reduce(512, rc_0 * 8 + rc_1) + T.reads(PadInput[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], weight[3 - v_rh, 3 - v_rw, v_rc, v_co]) + T.writes(conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - 
conv2d_transpose_nhwc_global[n, h, w, co] = T.float32(0) - conv2d_transpose_nhwc_global[n, h, w, co] = conv2d_transpose_nhwc_global[n, h, w, co] + T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], T.float32(0), dtype="float32") * weight[3 - rh, 3 - rw, rc, co] + conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] + T.if_then_else((v_h + v_rh) % 2 == 0 and (v_w + v_rw) % 2 == 0, PadInput[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], T.float32(0)) * weight[3 - v_rh, 3 - v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 8, 4, 32): with T.block("conv2d_transpose_nhwc_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(8, i2_0 * 4 + ax2) - v3 = T.axis.spatial(256, i3_0 * 32 + ax3) + v2 = T.axis.spatial(8, w_0 * 4 + ax2) + v3 = T.axis.spatial(256, co_0 * 32 + ax3) T.reads(conv2d_transpose_nhwc_global[v0, v1, v2, v3]) T.writes(conv2d_transpose_nhwc[v0, v1, v2, v3]) conv2d_transpose_nhwc[v0, v1, v2, v3] = conv2d_transpose_nhwc_global[v0, v1, v2, v3] @T.prim_func def t2d_2(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) - for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 2, 8, 1, 4, 1, 4, 2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + for n_0, h_0, w_0, co_0, n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 1, 2, 8, 1, 4, 1, 4, 2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8): with T.block("conv2d_transpose_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - h = T.axis.spatial(8, i1_0 * 8 + i1_1 * 2 + i1_2 * 2 + i1_3) - w = T.axis.spatial(8, i2_0 * 4 + i2_1 * 4 + i2_2 * 4 + i2_3) - co = T.axis.spatial(256, i3_0 * 32 + i3_1 * 8 + i3_2 * 8 + i3_3) - rh = T.axis.reduce(4, i4_0 * 2 + i4_1) - rw = T.axis.reduce(4, i5_0 * 2 + i5_1) - rc = T.axis.reduce(512, i6_0 * 8 + i6_1) - T.reads(inputs[n, (h + rh) // 2 - 1, (w + rw) // 2 - 1, rc], weight[3 - rh, 3 - rw, rc, co]) - T.writes(conv2d_transpose_nhwc[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) + v_h = T.axis.spatial(8, h_0 * 8 + h_1 * 2 + h_2 * 2 + h_3) + v_w = T.axis.spatial(8, w_0 * 4 + w_1 * 4 + w_2 * 4 + w_3) + v_co = T.axis.spatial(256, co_0 * 32 + co_1 * 8 + co_2 * 8 + co_3) + v_rh = T.axis.reduce(4, rh_0 * 2 + rh_1) + v_rw = T.axis.reduce(4, rw_0 * 2 + rw_1) + v_rc = T.axis.reduce(512, rc_0 * 8 + rc_1) + T.reads(inputs[v_n, (v_h + v_rh) // 2 - 1, (v_w + v_rw) // 2 - 1, v_rc], weight[3 - v_rh, 3 - v_rw, v_rc, v_co]) + T.writes(conv2d_transpose_nhwc[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - conv2d_transpose_nhwc[n, h, w, co] = T.float32(0) - conv2d_transpose_nhwc[n, h, w, co] = conv2d_transpose_nhwc[n, h, w, co] + 
T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, T.if_then_else(1 <= (h + rh) // 2 and (h + rh) // 2 < 5 and 1 <= (w + rw) // 2 and (w + rw) // 2 < 5, inputs[n, (h + rh) // 2 - 1, (w + rw) // 2 - 1, rc], T.float32(0), dtype="float32"), T.float32(0), dtype="float32") * weight[3 - rh, 3 - rw, rc, co] + conv2d_transpose_nhwc[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_transpose_nhwc[v_n, v_h, v_w, v_co] = conv2d_transpose_nhwc[v_n, v_h, v_w, v_co] + T.if_then_else((v_h + v_rh) % 2 == 0 and (v_w + v_rw) % 2 == 0, T.if_then_else(1 <= (v_h + v_rh) // 2 and (v_h + v_rh) // 2 < 5 and 1 <= (v_w + v_rw) // 2 and (v_w + v_rw) // 2 < 5, inputs[v_n, (v_h + v_rh) // 2 - 1, (v_w + v_rw) // 2 - 1, v_rc], T.float32(0)), T.float32(0)) * weight[3 - v_rh, 3 - v_rw, v_rc, v_co] # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 1, 1, 1]), @@ -1508,94 +1454,88 @@ def test_cpu_nrm(): # fmt: off @T.prim_func def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) - C = T.alloc_buffer([1], dtype="float32") - C_rf = T.alloc_buffer([1, 32768], dtype="float32") - for i0, i1_i2_fused_0, i1_i2_fused_1 in T.grid(1, 32768, 2): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + C = T.alloc_buffer((1,)) + C_rf = T.alloc_buffer((1, 32768)) + for b, i_j_fused_0, i_j_fused_1 in T.grid(1, 32768, 2): with T.block("C_rf"): - vi1_i2_fused_0, b, vi1_i2_fused_1 = T.axis.remap("SSR", [i1_i2_fused_0, i0, i1_i2_fused_1]) - T.reads(A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256]) - T.writes(C_rf[b, vi1_i2_fused_0]) + vi_j_fused_0, v_b, vi_j_fused_1 = T.axis.remap("SSR", [i_j_fused_0, b, i_j_fused_1]) + T.reads(A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256]) + T.writes(C_rf[v_b, vi_j_fused_0]) with T.init(): - C_rf[b, vi1_i2_fused_0] = T.float32(0) - C_rf[b, vi1_i2_fused_0] = C_rf[b, vi1_i2_fused_0] + A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256] * A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256] - for i0, i1_i2_fused_0 in T.grid(1, 32768): + C_rf[v_b, vi_j_fused_0] = T.float32(0) + C_rf[v_b, vi_j_fused_0] = C_rf[v_b, vi_j_fused_0] + A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256] * A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256] + for b, i_j_fused_0 in T.grid(1, 32768): with T.block("C"): - vi1_i2_fused_0, b = T.axis.remap("RS", [i1_i2_fused_0, i0]) - T.reads(C_rf[b, vi1_i2_fused_0]) - T.writes(C[b]) + vi_j_fused_0, v_b = T.axis.remap("RS", [i_j_fused_0, b]) + T.reads(C_rf[v_b, vi_j_fused_0]) + T.writes(C[v_b]) with T.init(): - C[b] = T.float32(0) - C[b] = C[b] + C_rf[b, vi1_i2_fused_0] - for i0 in T.serial(1): + C[v_b] = T.float32(0) + C[v_b] = C[v_b] + C_rf[v_b, vi_j_fused_0] + for b in range(1): with T.block("D"): - b = T.axis.spatial(1, i0) - T.reads(C[b]) - T.writes(D[b]) - D[b] = T.sqrt(C[b], dtype="float32") + v_b = T.axis.spatial(1, b) + T.reads(C[v_b]) + T.writes(D[v_b]) + D[v_b] = T.sqrt(C[v_b]) @T.prim_func def nrm_1(A: T.Buffer((1, 256, 256), 
"float32"), D: T.Buffer(1, "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - C = T.alloc_buffer([1], dtype="float32") - C_rf = T.alloc_buffer([1, 2], dtype="float32") - for i0, i1_i2_fused_0, i1_i2_fused_1 in T.grid(1, 32768, 2): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + C = T.alloc_buffer((1,)) + C_rf = T.alloc_buffer((1, 2)) + for b, i_j_fused_0, i_j_fused_1 in T.grid(1, 32768, 2): with T.block("C_rf"): - vi1_i2_fused_1, b, vi1_i2_fused_0 = T.axis.remap("SSR", [i1_i2_fused_1, i0, i1_i2_fused_0]) - T.reads(A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256]) - T.writes(C_rf[b, vi1_i2_fused_1]) + vi_j_fused_1, v_b, vi_j_fused_0 = T.axis.remap("SSR", [i_j_fused_1, b, i_j_fused_0]) + T.reads(A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256]) + T.writes(C_rf[v_b, vi_j_fused_1]) with T.init(): - C_rf[b, vi1_i2_fused_1] = T.float32(0) - C_rf[b, vi1_i2_fused_1] = C_rf[b, vi1_i2_fused_1] + A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256] * A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256] - for i0, i1_i2_fused_1 in T.grid(1, 2): + C_rf[v_b, vi_j_fused_1] = T.float32(0) + C_rf[v_b, vi_j_fused_1] = C_rf[v_b, vi_j_fused_1] + A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256] * A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256] + for b, i_j_fused_1 in T.grid(1, 2): with T.block("C"): - vi1_i2_fused_1, b = T.axis.remap("RS", [i1_i2_fused_1, i0]) - T.reads(C_rf[b, vi1_i2_fused_1]) - T.writes(C[b]) + vi_j_fused_1, v_b = T.axis.remap("RS", [i_j_fused_1, b]) + T.reads(C_rf[v_b, vi_j_fused_1]) + T.writes(C[v_b]) with T.init(): - C[b] = T.float32(0) - C[b] = C[b] + C_rf[b, vi1_i2_fused_1] - for i0 in T.serial(1): + C[v_b] = T.float32(0) + C[v_b] = C[v_b] + C_rf[v_b, vi_j_fused_1] + for b in range(1): with T.block("D"): - b = T.axis.spatial(1, i0) - T.reads(C[b]) - T.writes(D[b]) - D[b] = T.sqrt(C[b], dtype="float32") + v_b = T.axis.spatial(1, b) + T.reads(C[v_b]) + T.writes(D[v_b]) + D[v_b] = T.sqrt(C[v_b]) @T.prim_func def nrm_2(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) - C = T.alloc_buffer([1], dtype="float32") - for i0, i1, i2 in T.grid(1, 256, 256): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + C = T.alloc_buffer((1,)) + for b, i, j in T.grid(1, 256, 256): with T.block("C"): - b, i, j = T.axis.remap("SRR", [i0, i1, i2]) - T.reads(A[b, i, j]) - T.writes(C[b]) + v_b, v_i, v_j = T.axis.remap("SRR", [b, i, j]) + T.reads(A[v_b, v_i, v_j]) + T.writes(C[v_b]) with T.init(): - C[b] = T.float32(0) - C[b] = C[b] + A[b, i, j] * A[b, i, j] - for i0 in T.serial(1): + C[v_b] = T.float32(0) 
+ C[v_b] = C[v_b] + A[v_b, v_i, v_j] * A[v_b, v_i, v_j] + for b in range(1): with T.block("D"): - b = T.axis.spatial(1, i0) - T.reads(C[b]) - T.writes(D[b]) - D[b] = T.sqrt(C[b], dtype="float32") + v_b = T.axis.spatial(1, b) + T.reads(C[v_b]) + T.writes(D[v_b]) + D[v_b] = T.sqrt(C[v_b]) # fmt: on decision_0 = [ ("SamplePerfectTile", [32768, 2]), @@ -1627,482 +1567,464 @@ def test_cpu_sfm(): # fmt: off @T.prim_func def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum_rf = T.alloc_buffer([256, 16], dtype="float32") - T_softmax_maxelem_rf = T.alloc_buffer([256, 4], dtype="float32") - for i0, i1_0, i1_1 in T.grid(256, 4, 64): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_expsum = T.alloc_buffer((256,)) + T_softmax_expsum_rf = T.alloc_buffer((256, 16)) + T_softmax_maxelem_rf = T.alloc_buffer((256, 4)) + for i0, k_0, k_1 in T.grid(256, 4, 64): with T.block("T_softmax_maxelem_rf"): - vi1_0, i0_1, vi1_1 = T.axis.remap("SSR", [i1_0, i0, i1_1]) - T.reads(A[i0_1, vi1_0 * 64 + vi1_1]) - T.writes(T_softmax_maxelem_rf[i0_1, vi1_0]) + vk_0, v_i0, vk_1 = T.axis.remap("SSR", [k_0, i0, k_1]) + T.reads(A[v_i0, vk_0 * 64 + vk_1]) + T.writes(T_softmax_maxelem_rf[v_i0, vk_0]) with T.init(): - T_softmax_maxelem_rf[i0_1, vi1_0] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_rf[i0_1, vi1_0] = T.max(T_softmax_maxelem_rf[i0_1, vi1_0], A[i0_1, vi1_0 * 64 + vi1_1]) - for i0, i1_0 in T.grid(256, 4): + T_softmax_maxelem_rf[v_i0, vk_0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_rf[v_i0, vk_0] = T.max(T_softmax_maxelem_rf[v_i0, vk_0], A[v_i0, vk_0 * 64 + vk_1]) + for i0, k_0 in T.grid(256, 4): with T.block("T_softmax_maxelem"): - vi1_0, i0_2 = T.axis.remap("RS", [i1_0, i0]) - T.reads(T_softmax_maxelem_rf[i0_2, vi1_0]) - T.writes(T_softmax_maxelem[i0_2]) + vk_0, v_i0 = T.axis.remap("RS", [k_0, i0]) + T.reads(T_softmax_maxelem_rf[v_i0, vk_0]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0_2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0_2] = T.max(T_softmax_maxelem[i0_2], T_softmax_maxelem_rf[i0_2, vi1_0]) - for i0_3, i1_0, i1_1 in T.grid(256, 16, 16): + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_0]) + for i0, k_0, k_1 in T.grid(256, 16, 16): with T.block("T_softmax_expsum_rf"): - vi1_0, i0_4, vi1_1 = T.axis.remap("SSR", [i1_0, i0_3, i1_1]) - T.reads(A[i0_4, vi1_0 * 16 + vi1_1], T_softmax_maxelem[i0_4]) - T.writes(T_softmax_expsum_rf[i0_4, vi1_0]) + vk_0, v_i0, vk_1 = T.axis.remap("SSR", [k_0, i0, k_1]) + T.reads(A[v_i0, vk_0 * 16 + vk_1], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_expsum_rf[v_i0, vk_0]) with T.init(): - T_softmax_expsum_rf[i0_4, vi1_0] = T.float32(0) - T_softmax_expsum_rf[i0_4, vi1_0] = T_softmax_expsum_rf[i0_4, vi1_0] + T.exp(A[i0_4, vi1_0 * 16 + vi1_1] - T_softmax_maxelem[i0_4], dtype="float32") - for i0_5, 
i1 in T.grid(256, 256): + T_softmax_expsum_rf[v_i0, vk_0] = T.float32(0) + T_softmax_expsum_rf[v_i0, vk_0] = T_softmax_expsum_rf[v_i0, vk_0] + T.exp(A[v_i0, vk_0 * 16 + vk_1] - T_softmax_maxelem[v_i0]) + for i0, i1 in T.grid(256, 256): for ax0, ax1 in T.grid(16, 1): with T.block("T_softmax_expsum"): - vi1_0 = T.axis.reduce(16, ax0) - i0_6 = T.axis.spatial(256, i0_5 + ax1) - T.reads(T_softmax_expsum_rf[i0_6, vi1_0]) - T.writes(T_softmax_expsum[i0_6]) + vk_0 = T.axis.reduce(16, ax0) + v_i0 = T.axis.spatial(256, i0 + ax1) + T.reads(T_softmax_expsum_rf[v_i0, vk_0]) + T.writes(T_softmax_expsum[v_i0]) with T.init(): - T_softmax_expsum[i0_6] = T.float32(0) - T_softmax_expsum[i0_6] = T_softmax_expsum[i0_6] + T_softmax_expsum_rf[i0_6, vi1_0] + T_softmax_expsum[v_i0] = T.float32(0) + T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_0] with T.block("T_softmax_norm"): - i0_7, i1_2 = T.axis.remap("SS", [i0_5, i1]) - T.reads(A[i0_7, i1_2], T_softmax_maxelem[i0_7], T_softmax_expsum[i0_7]) - T.writes(T_softmax_norm[i0_7, i1_2]) - T.block_attr({"axis":1}) - T_softmax_norm[i0_7, i1_2] = T.exp(A[i0_7, i1_2] - T_softmax_maxelem[i0_7], dtype="float32") / T_softmax_expsum[i0_7] + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") - T_softmax_expsum = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum_rf = T.alloc_buffer([256, 16], dtype="float32") - T_softmax_maxelem_rf = T.alloc_buffer([256, 64], dtype="float32") - for i0 in T.serial(256): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_exp = T.alloc_buffer((256, 256)) + T_softmax_expsum = T.alloc_buffer((256,)) + T_softmax_expsum_rf = T.alloc_buffer((256, 16)) + T_softmax_maxelem_rf = T.alloc_buffer((256, 64)) + for i0 in range(256): for ax0, ax1, ax2 in T.grid(64, 1, 4): with T.block("T_softmax_maxelem_rf"): - vi1_1 = T.axis.spatial(64, ax0) - i0_1 = T.axis.spatial(256, i0 + ax1) - vi1_0 = T.axis.reduce(4, ax2) - T.reads(A[i0_1, vi1_0 * 64 + vi1_1]) - T.writes(T_softmax_maxelem_rf[i0_1, vi1_1]) + vk_1 = T.axis.spatial(64, ax0) + v_i0 = T.axis.spatial(256, i0 + ax1) + vk_0 = T.axis.reduce(4, ax2) + T.reads(A[v_i0, vk_0 * 64 + vk_1]) + T.writes(T_softmax_maxelem_rf[v_i0, vk_1]) with T.init(): - T_softmax_maxelem_rf[i0_1, vi1_1] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_rf[i0_1, vi1_1] = T.max(T_softmax_maxelem_rf[i0_1, vi1_1], A[i0_1, vi1_0 * 64 + vi1_1]) - for i1 in T.serial(256): + T_softmax_maxelem_rf[v_i0, vk_1] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_rf[v_i0, vk_1] = T.max(T_softmax_maxelem_rf[v_i0, vk_1], A[v_i0, vk_0 * 64 + vk_1]) + for i1 in range(256): for ax0, ax1 in T.grid(64, 
1): with T.block("T_softmax_maxelem"): - vi1_1 = T.axis.reduce(64, ax0) - i0_2 = T.axis.spatial(256, i0 + ax1) - T.reads(T_softmax_maxelem_rf[i0_2, vi1_1]) - T.writes(T_softmax_maxelem[i0_2]) + vk_1 = T.axis.reduce(64, ax0) + v_i0 = T.axis.spatial(256, i0 + ax1) + T.reads(T_softmax_maxelem_rf[v_i0, vk_1]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0_2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0_2] = T.max(T_softmax_maxelem[i0_2], T_softmax_maxelem_rf[i0_2, vi1_1]) + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_1]) with T.block("T_softmax_exp"): - i0_3, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(A[i0_3, i1_1], T_softmax_maxelem[i0_3]) - T.writes(T_softmax_exp[i0_3, i1_1]) - T_softmax_exp[i0_3, i1_1] = T.exp(A[i0_3, i1_1] - T_softmax_maxelem[i0_3], dtype="float32") - for i0_4, i1_0, i1_1_1 in T.grid(256, 16, 16): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_exp[v_i0, v_i1]) + T_softmax_exp[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) + for i0, k_0, k_1 in T.grid(256, 16, 16): with T.block("T_softmax_expsum_rf"): - vi1_0, i0_5, vi1_1 = T.axis.remap("SSR", [i1_0, i0_4, i1_1_1]) - T.reads(T_softmax_exp[i0_5, vi1_0 * 16 + vi1_1]) - T.writes(T_softmax_expsum_rf[i0_5, vi1_0]) + vk_0, v_i0, vk_1 = T.axis.remap("SSR", [k_0, i0, k_1]) + T.reads(T_softmax_exp[v_i0, vk_0 * 16 + vk_1]) + T.writes(T_softmax_expsum_rf[v_i0, vk_0]) with T.init(): - T_softmax_expsum_rf[i0_5, vi1_0] = T.float32(0) - T_softmax_expsum_rf[i0_5, vi1_0] = T_softmax_expsum_rf[i0_5, vi1_0] + T_softmax_exp[i0_5, vi1_0 * 16 + vi1_1] - for i0_6, i1_0 in T.grid(256, 16): + T_softmax_expsum_rf[v_i0, vk_0] = T.float32(0) + T_softmax_expsum_rf[v_i0, vk_0] = T_softmax_expsum_rf[v_i0, vk_0] + T_softmax_exp[v_i0, vk_0 * 16 + vk_1] + for i0, k_0 in T.grid(256, 16): with T.block("T_softmax_expsum"): - vi1_0, i0_7 = T.axis.remap("RS", [i1_0, i0_6]) - T.reads(T_softmax_expsum_rf[i0_7, vi1_0]) - T.writes(T_softmax_expsum[i0_7]) + vk_0, v_i0 = T.axis.remap("RS", [k_0, i0]) + T.reads(T_softmax_expsum_rf[v_i0, vk_0]) + T.writes(T_softmax_expsum[v_i0]) with T.init(): - T_softmax_expsum[i0_7] = T.float32(0) - T_softmax_expsum[i0_7] = T_softmax_expsum[i0_7] + T_softmax_expsum_rf[i0_7, vi1_0] - for i0_8, i1 in T.grid(256, 256): + T_softmax_expsum[v_i0] = T.float32(0) + T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_0] + for i0, i1 in T.grid(256, 256): with T.block("T_softmax_norm"): - i0_9, i1_2 = T.axis.remap("SS", [i0_8, i1]) - T.reads(T_softmax_exp[i0_9, i1_2], T_softmax_expsum[i0_9]) - T.writes(T_softmax_norm[i0_9, i1_2]) - T.block_attr({"axis":1}) - T_softmax_norm[i0_9, i1_2] = T_softmax_exp[i0_9, i1_2] / T_softmax_expsum[i0_9] + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, 
"meta_schedule.vectorize":64}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum_rf = T.alloc_buffer([256, 16], dtype="float32") - for i0, i1 in T.grid(256, 256): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_expsum = T.alloc_buffer((256,)) + T_softmax_expsum_rf = T.alloc_buffer((256, 16)) + for i0, k in T.grid(256, 256): with T.block("T_softmax_maxelem"): - i0_1, k = T.axis.remap("SR", [i0, i1]) - T.reads(A[i0_1, k]) - T.writes(T_softmax_maxelem[i0_1]) + v_i0, v_k = T.axis.remap("SR", [i0, k]) + T.reads(A[v_i0, v_k]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) - for i0, i1_0, i1_1 in T.grid(256, 16, 16): + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k]) + for i0, k_0, k_1 in T.grid(256, 16, 16): with T.block("T_softmax_expsum_rf"): - vi1_0, i0_2, vi1_1 = T.axis.remap("SSR", [i1_0, i0, i1_1]) - T.reads(A[i0_2, vi1_0 * 16 + vi1_1], T_softmax_maxelem[i0_2]) - T.writes(T_softmax_expsum_rf[i0_2, vi1_0]) + vk_0, v_i0, vk_1 = T.axis.remap("SSR", [k_0, i0, k_1]) + T.reads(A[v_i0, vk_0 * 16 + vk_1], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_expsum_rf[v_i0, vk_0]) with T.init(): - T_softmax_expsum_rf[i0_2, vi1_0] = T.float32(0) - T_softmax_expsum_rf[i0_2, vi1_0] = T_softmax_expsum_rf[i0_2, vi1_0] + T.exp(A[i0_2, vi1_0 * 16 + vi1_1] - T_softmax_maxelem[i0_2], dtype="float32") - for i0_3, i1_0 in T.grid(256, 16): + T_softmax_expsum_rf[v_i0, vk_0] = T.float32(0) + T_softmax_expsum_rf[v_i0, vk_0] = T_softmax_expsum_rf[v_i0, vk_0] + T.exp(A[v_i0, vk_0 * 16 + vk_1] - T_softmax_maxelem[v_i0]) + for i0, k_0 in T.grid(256, 16): with T.block("T_softmax_expsum"): - vi1_0, i0_4 = T.axis.remap("RS", [i1_0, i0_3]) - T.reads(T_softmax_expsum_rf[i0_4, vi1_0]) - T.writes(T_softmax_expsum[i0_4]) + vk_0, v_i0 = T.axis.remap("RS", [k_0, i0]) + T.reads(T_softmax_expsum_rf[v_i0, vk_0]) + T.writes(T_softmax_expsum[v_i0]) with T.init(): - T_softmax_expsum[i0_4] = T.float32(0) - T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_expsum_rf[i0_4, vi1_0] - for i0_5, i1 in T.grid(256, 256): + T_softmax_expsum[v_i0] = T.float32(0) + T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_0] + for i0, i1 in T.grid(256, 256): with T.block("T_softmax_norm"): - i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1]) - T.reads(A[i0_6, i1_2], T_softmax_maxelem[i0_6], T_softmax_expsum[i0_6]) - T.writes(T_softmax_norm[i0_6, i1_2]) - T.block_attr({"axis":1}) - T_softmax_norm[i0_6, i1_2] = T.exp(A[i0_6, i1_2] - T_softmax_maxelem[i0_6], dtype="float32") / T_softmax_expsum[i0_6] + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() 
- T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") - T_softmax_expsum = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum_rf = T.alloc_buffer([256, 16], dtype="float32") - T_softmax_maxelem_rf = T.alloc_buffer([256, 256], dtype="float32") + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_exp = T.alloc_buffer((256, 256)) + T_softmax_expsum = T.alloc_buffer((256,)) + T_softmax_expsum_rf = T.alloc_buffer((256, 16)) + T_softmax_maxelem_rf = T.alloc_buffer((256, 256)) for i0, i1 in T.grid(256, 256): for ax0, ax1, ax2 in T.grid(256, 1, 1): with T.block("T_softmax_maxelem_rf"): - vi1_0 = T.axis.spatial(256, ax0) - i0_1 = T.axis.spatial(256, i0 + ax1) - vi1_1 = T.axis.reduce(1, ax2) - T.reads(A[i0_1, vi1_1 + vi1_0]) - T.writes(T_softmax_maxelem_rf[i0_1, vi1_0]) + vk_0 = T.axis.spatial(256, ax0) + v_i0 = T.axis.spatial(256, i0 + ax1) + vk_1 = T.axis.reduce(1, ax2) + T.reads(A[v_i0, vk_0 + vk_1]) + T.writes(T_softmax_maxelem_rf[v_i0, vk_0]) with T.init(): - T_softmax_maxelem_rf[i0_1, vi1_0] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_rf[i0_1, vi1_0] = T.max(T_softmax_maxelem_rf[i0_1, vi1_0], A[i0_1, vi1_1 + vi1_0]) + T_softmax_maxelem_rf[v_i0, vk_0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_rf[v_i0, vk_0] = T.max(T_softmax_maxelem_rf[v_i0, vk_0], A[v_i0, vk_0 + vk_1]) for ax0, ax1 in T.grid(256, 1): with T.block("T_softmax_maxelem"): - vi1_0 = T.axis.reduce(256, ax0) - i0_2 = T.axis.spatial(256, i0 + ax1) - T.reads(T_softmax_maxelem_rf[i0_2, vi1_0]) - T.writes(T_softmax_maxelem[i0_2]) + vk_0 = T.axis.reduce(256, ax0) + v_i0 = T.axis.spatial(256, i0 + ax1) + T.reads(T_softmax_maxelem_rf[v_i0, vk_0]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0_2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0_2] = T.max(T_softmax_maxelem[i0_2], T_softmax_maxelem_rf[i0_2, vi1_0]) + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_0]) for ax0, ax1 in T.grid(1, 256): with T.block("T_softmax_exp"): - i0_3 = T.axis.spatial(256, i0 + ax0) - i1_1 = T.axis.spatial(256, ax1) - T.reads(A[i0_3, i1_1], T_softmax_maxelem[i0_3]) - T.writes(T_softmax_exp[i0_3, i1_1]) - T_softmax_exp[i0_3, i1_1] = T.exp(A[i0_3, i1_1] - T_softmax_maxelem[i0_3], dtype="float32") - for ax0 in T.serial(16): + v_i0 = T.axis.spatial(256, i0 + ax0) + v_i1 = T.axis.spatial(256, ax1) + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_exp[v_i0, v_i1]) + T_softmax_exp[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) + for ax0 in range(16): for ax0_1, ax1, ax2 in T.grid(1, 1, 16): with T.block("T_softmax_expsum_rf"): - vi1_1 = T.axis.spatial(16, ax0 + ax0_1) - i0_4 = T.axis.spatial(256, i0 + ax1) - vi1_0 = T.axis.reduce(16, ax2) - T.reads(T_softmax_exp[i0_4, vi1_0 * 16 + vi1_1]) - T.writes(T_softmax_expsum_rf[i0_4, vi1_1]) + vk_1 = T.axis.spatial(16, ax0 + ax0_1) + v_i0 = T.axis.spatial(256, i0 + ax1) + vk_0 = T.axis.reduce(16, ax2) + T.reads(T_softmax_exp[v_i0, vk_0 * 16 + vk_1]) + T.writes(T_softmax_expsum_rf[v_i0, vk_1]) with T.init(): - T_softmax_expsum_rf[i0_4, vi1_1] = T.float32(0) - T_softmax_expsum_rf[i0_4, vi1_1] = T_softmax_expsum_rf[i0_4, 
vi1_1] + T_softmax_exp[i0_4, vi1_0 * 16 + vi1_1] - for ax1 in T.serial(1): + T_softmax_expsum_rf[v_i0, vk_1] = T.float32(0) + T_softmax_expsum_rf[v_i0, vk_1] = T_softmax_expsum_rf[v_i0, vk_1] + T_softmax_exp[v_i0, vk_0 * 16 + vk_1] + for ax1 in range(1): with T.block("T_softmax_expsum"): - vi1_1 = T.axis.reduce(16, ax0) - i0_5 = T.axis.spatial(256, i0 + ax1) - T.reads(T_softmax_expsum_rf[i0_5, vi1_1]) - T.writes(T_softmax_expsum[i0_5]) + vk_1 = T.axis.reduce(16, ax0) + v_i0 = T.axis.spatial(256, i0 + ax1) + T.reads(T_softmax_expsum_rf[v_i0, vk_1]) + T.writes(T_softmax_expsum[v_i0]) with T.init(): - T_softmax_expsum[i0_5] = T.float32(0) - T_softmax_expsum[i0_5] = T_softmax_expsum[i0_5] + T_softmax_expsum_rf[i0_5, vi1_1] + T_softmax_expsum[v_i0] = T.float32(0) + T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_1] with T.block("T_softmax_norm"): - i0_6, i1_2 = T.axis.remap("SS", [i0, i1]) - T.reads(T_softmax_exp[i0_6, i1_2], T_softmax_expsum[i0_6]) - T.writes(T_softmax_norm[i0_6, i1_2]) - T.block_attr({"axis":1}) - T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6] + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_4(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") - T_softmax_expsum = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum_rf = T.alloc_buffer([256, 16], dtype="float32") - T_softmax_maxelem_rf = T.alloc_buffer([256, 1], dtype="float32") - for i0 in T.serial(256): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_exp = T.alloc_buffer((256, 256)) + T_softmax_expsum = T.alloc_buffer((256,)) + T_softmax_expsum_rf = T.alloc_buffer((256, 16)) + T_softmax_maxelem_rf = T.alloc_buffer((256, 1)) + for i0 in range(256): for ax0, ax1, ax2 in T.grid(1, 1, 256): with T.block("T_softmax_maxelem_rf"): - vi1_1 = T.axis.spatial(1, ax0) - i0_1 = T.axis.spatial(256, i0 + ax1) - vi1_0 = T.axis.reduce(256, ax2) - T.reads(A[i0_1, vi1_1 + vi1_0]) - T.writes(T_softmax_maxelem_rf[i0_1, vi1_1]) + vk_1 = T.axis.spatial(1, ax0) + v_i0 = T.axis.spatial(256, i0 + ax1) + vk_0 = T.axis.reduce(256, ax2) + T.reads(A[v_i0, vk_0 + vk_1]) + T.writes(T_softmax_maxelem_rf[v_i0, vk_1]) with T.init(): - T_softmax_maxelem_rf[i0_1, vi1_1] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_rf[i0_1, vi1_1] = T.max(T_softmax_maxelem_rf[i0_1, vi1_1], A[i0_1, vi1_1 + vi1_0]) - for i1_1 in T.serial(1): + T_softmax_maxelem_rf[v_i0, vk_1] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_rf[v_i0, vk_1] = T.max(T_softmax_maxelem_rf[v_i0, vk_1], A[v_i0, vk_0 + vk_1]) + for k_1 in range(1): with T.block("T_softmax_maxelem"): - vi1_1, i0_2 = T.axis.remap("RS", [i1_1, i0]) - T.reads(T_softmax_maxelem_rf[i0_2, vi1_1]) - 
T.writes(T_softmax_maxelem[i0_2]) + vk_1, v_i0 = T.axis.remap("RS", [k_1, i0]) + T.reads(T_softmax_maxelem_rf[v_i0, vk_1]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0_2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0_2] = T.max(T_softmax_maxelem[i0_2], T_softmax_maxelem_rf[i0_2, vi1_1]) - for i0_3, i1 in T.grid(256, 256): + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_1]) + for i0, i1 in T.grid(256, 256): with T.block("T_softmax_exp"): - i0_4, i1_2 = T.axis.remap("SS", [i0_3, i1]) - T.reads(A[i0_4, i1_2], T_softmax_maxelem[i0_4]) - T.writes(T_softmax_exp[i0_4, i1_2]) - T_softmax_exp[i0_4, i1_2] = T.exp(A[i0_4, i1_2] - T_softmax_maxelem[i0_4], dtype="float32") - for i0_5, i1_0, i1_1 in T.grid(256, 16, 16): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_exp[v_i0, v_i1]) + T_softmax_exp[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) + for i0, k_0, k_1 in T.grid(256, 16, 16): with T.block("T_softmax_expsum_rf"): - vi1_1, i0_6, vi1_0 = T.axis.remap("SSR", [i1_1, i0_5, i1_0]) - T.reads(T_softmax_exp[i0_6, vi1_0 * 16 + vi1_1]) - T.writes(T_softmax_expsum_rf[i0_6, vi1_1]) + vk_1, v_i0, vk_0 = T.axis.remap("SSR", [k_1, i0, k_0]) + T.reads(T_softmax_exp[v_i0, vk_0 * 16 + vk_1]) + T.writes(T_softmax_expsum_rf[v_i0, vk_1]) with T.init(): - T_softmax_expsum_rf[i0_6, vi1_1] = T.float32(0) - T_softmax_expsum_rf[i0_6, vi1_1] = T_softmax_expsum_rf[i0_6, vi1_1] + T_softmax_exp[i0_6, vi1_0 * 16 + vi1_1] - for i0_7, i1_1 in T.grid(256, 16): + T_softmax_expsum_rf[v_i0, vk_1] = T.float32(0) + T_softmax_expsum_rf[v_i0, vk_1] = T_softmax_expsum_rf[v_i0, vk_1] + T_softmax_exp[v_i0, vk_0 * 16 + vk_1] + for i0, k_1 in T.grid(256, 16): with T.block("T_softmax_expsum"): - vi1_1, i0_8 = T.axis.remap("RS", [i1_1, i0_7]) - T.reads(T_softmax_expsum_rf[i0_8, vi1_1]) - T.writes(T_softmax_expsum[i0_8]) + vk_1, v_i0 = T.axis.remap("RS", [k_1, i0]) + T.reads(T_softmax_expsum_rf[v_i0, vk_1]) + T.writes(T_softmax_expsum[v_i0]) with T.init(): - T_softmax_expsum[i0_8] = T.float32(0) - T_softmax_expsum[i0_8] = T_softmax_expsum[i0_8] + T_softmax_expsum_rf[i0_8, vi1_1] - for i0_9, i1_3 in T.grid(256, 256): + T_softmax_expsum[v_i0] = T.float32(0) + T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_1] + for i0, i1 in T.grid(256, 256): with T.block("T_softmax_norm"): - i0_10, i1_4 = T.axis.remap("SS", [i0_9, i1_3]) - T.reads(T_softmax_exp[i0_10, i1_4], T_softmax_expsum[i0_10]) - T.writes(T_softmax_norm[i0_10, i1_4]) - T.block_attr({"axis":1}) - T_softmax_norm[i0_10, i1_4] = T_softmax_exp[i0_10, i1_4] / T_softmax_expsum[i0_10] + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_5(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_exp = 
T.alloc_buffer([256, 256], dtype="float32") - T_softmax_expsum = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum_rf = T.alloc_buffer([256, 16], dtype="float32") - for i0 in T.serial(256): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_exp = T.alloc_buffer((256, 256)) + T_softmax_expsum = T.alloc_buffer((256,)) + T_softmax_expsum_rf = T.alloc_buffer((256, 16)) + for i0 in range(256): for ax0, ax1 in T.grid(1, 256): with T.block("T_softmax_maxelem"): - i0_1 = T.axis.spatial(256, i0 + ax0) - k = T.axis.reduce(256, ax1) - T.reads(A[i0_1, k]) - T.writes(T_softmax_maxelem[i0_1]) + v_i0 = T.axis.spatial(256, i0 + ax0) + v_k = T.axis.reduce(256, ax1) + T.reads(A[v_i0, v_k]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k]) for ax0, ax1 in T.grid(1, 256): with T.block("T_softmax_exp"): - i0_2 = T.axis.spatial(256, i0 + ax0) - i1 = T.axis.spatial(256, ax1) - T.reads(A[i0_2, i1], T_softmax_maxelem[i0_2]) - T.writes(T_softmax_exp[i0_2, i1]) - T_softmax_exp[i0_2, i1] = T.exp(A[i0_2, i1] - T_softmax_maxelem[i0_2], dtype="float32") - for ax0 in T.serial(16): + v_i0 = T.axis.spatial(256, i0 + ax0) + v_i1 = T.axis.spatial(256, ax1) + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_exp[v_i0, v_i1]) + T_softmax_exp[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) + for ax0 in range(16): for ax0_1, ax1, ax2 in T.grid(1, 1, 16): with T.block("T_softmax_expsum_rf"): - vi1_1 = T.axis.spatial(16, ax0 + ax0_1) - i0_3 = T.axis.spatial(256, i0 + ax1) - vi1_0 = T.axis.reduce(16, ax2) - T.reads(T_softmax_exp[i0_3, vi1_0 * 16 + vi1_1]) - T.writes(T_softmax_expsum_rf[i0_3, vi1_1]) + vk_1 = T.axis.spatial(16, ax0 + ax0_1) + v_i0 = T.axis.spatial(256, i0 + ax1) + vk_0 = T.axis.reduce(16, ax2) + T.reads(T_softmax_exp[v_i0, vk_0 * 16 + vk_1]) + T.writes(T_softmax_expsum_rf[v_i0, vk_1]) with T.init(): - T_softmax_expsum_rf[i0_3, vi1_1] = T.float32(0) - T_softmax_expsum_rf[i0_3, vi1_1] = T_softmax_expsum_rf[i0_3, vi1_1] + T_softmax_exp[i0_3, vi1_0 * 16 + vi1_1] - for ax1 in T.serial(1): + T_softmax_expsum_rf[v_i0, vk_1] = T.float32(0) + T_softmax_expsum_rf[v_i0, vk_1] = T_softmax_expsum_rf[v_i0, vk_1] + T_softmax_exp[v_i0, vk_0 * 16 + vk_1] + for ax1 in range(1): with T.block("T_softmax_expsum"): - vi1_1 = T.axis.reduce(16, ax0) - i0_4 = T.axis.spatial(256, i0 + ax1) - T.reads(T_softmax_expsum_rf[i0_4, vi1_1]) - T.writes(T_softmax_expsum[i0_4]) + vk_1 = T.axis.reduce(16, ax0) + v_i0 = T.axis.spatial(256, i0 + ax1) + T.reads(T_softmax_expsum_rf[v_i0, vk_1]) + T.writes(T_softmax_expsum[v_i0]) with T.init(): - T_softmax_expsum[i0_4] = T.float32(0) - T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_expsum_rf[i0_4, vi1_1] - for i1 in T.serial(256): + T_softmax_expsum[v_i0] = T.float32(0) + T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_1] + for i1 in range(256): with T.block("T_softmax_norm"): - i0_5, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(T_softmax_exp[i0_5, i1_1], T_softmax_expsum[i0_5]) - T.writes(T_softmax_norm[i0_5, i1_1]) - T.block_attr({"axis":1}) - T_softmax_norm[i0_5, i1_1] = T_softmax_exp[i0_5, i1_1] / T_softmax_expsum[i0_5] + v_i0, v_i1 = 
T.axis.remap("SS", [i0, i1]) + T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_6(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum = T.alloc_buffer([256], dtype="float32") - T_softmax_maxelem_rf = T.alloc_buffer([256, 64], dtype="float32") - for i0 in T.serial(256): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_expsum = T.alloc_buffer((256,)) + T_softmax_maxelem_rf = T.alloc_buffer((256, 64)) + for i0 in range(256): for ax0, ax1, ax2 in T.grid(64, 1, 4): with T.block("T_softmax_maxelem_rf"): - vi1_0 = T.axis.spatial(64, ax0) - i0_1 = T.axis.spatial(256, i0 + ax1) - vi1_1 = T.axis.reduce(4, ax2) - T.reads(A[i0_1, vi1_0 * 4 + vi1_1]) - T.writes(T_softmax_maxelem_rf[i0_1, vi1_0]) + vk_0 = T.axis.spatial(64, ax0) + v_i0 = T.axis.spatial(256, i0 + ax1) + vk_1 = T.axis.reduce(4, ax2) + T.reads(A[v_i0, vk_0 * 4 + vk_1]) + T.writes(T_softmax_maxelem_rf[v_i0, vk_0]) with T.init(): - T_softmax_maxelem_rf[i0_1, vi1_0] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_rf[i0_1, vi1_0] = T.max(T_softmax_maxelem_rf[i0_1, vi1_0], A[i0_1, vi1_0 * 4 + vi1_1]) - for i1_0 in T.serial(64): + T_softmax_maxelem_rf[v_i0, vk_0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_rf[v_i0, vk_0] = T.max(T_softmax_maxelem_rf[v_i0, vk_0], A[v_i0, vk_0 * 4 + vk_1]) + for k_0 in range(64): with T.block("T_softmax_maxelem"): - vi1_0, i0_2 = T.axis.remap("RS", [i1_0, i0]) - T.reads(T_softmax_maxelem_rf[i0_2, vi1_0]) - T.writes(T_softmax_maxelem[i0_2]) + vk_0, v_i0 = T.axis.remap("RS", [k_0, i0]) + T.reads(T_softmax_maxelem_rf[v_i0, vk_0]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0_2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0_2] = T.max(T_softmax_maxelem[i0_2], T_softmax_maxelem_rf[i0_2, vi1_0]) - for i0_3, i1 in T.grid(256, 256): + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_0]) + for i0, k in T.grid(256, 256): with T.block("T_softmax_expsum"): - i0_4, k = T.axis.remap("SR", [i0_3, i1]) - T.reads(A[i0_4, k], T_softmax_maxelem[i0_4]) - T.writes(T_softmax_expsum[i0_4]) + v_i0, v_k = T.axis.remap("SR", [i0, k]) + T.reads(A[v_i0, v_k], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_expsum[v_i0]) with T.init(): - T_softmax_expsum[i0_4] = T.float32(0) - T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T.exp(A[i0_4, k] - T_softmax_maxelem[i0_4], dtype="float32") - for i0_5, i1 in T.grid(256, 256): + T_softmax_expsum[v_i0] = T.float32(0) + T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T.exp(A[v_i0, v_k] - T_softmax_maxelem[v_i0]) + for i0, i1 in T.grid(256, 256): with T.block("T_softmax_norm"): - i0_6, i1_1 = T.axis.remap("SS", [i0_5, i1]) - T.reads(A[i0_6, i1_1], T_softmax_maxelem[i0_6], T_softmax_expsum[i0_6]) - 
T.writes(T_softmax_norm[i0_6, i1_1]) - T.block_attr({"axis":1}) - T_softmax_norm[i0_6, i1_1] = T.exp(A[i0_6, i1_1] - T_softmax_maxelem[i0_6], dtype="float32") / T_softmax_expsum[i0_6] + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_7(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum = T.alloc_buffer([256], dtype="float32") - T_softmax_maxelem_rf = T.alloc_buffer([256, 4], dtype="float32") - for i0, i1_0, i1_1 in T.grid(256, 64, 4): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_expsum = T.alloc_buffer((256,)) + T_softmax_maxelem_rf = T.alloc_buffer((256, 4)) + for i0, k_0, k_1 in T.grid(256, 64, 4): with T.block("T_softmax_maxelem_rf"): - vi1_1, i0_1, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) - T.reads(A[i0_1, vi1_0 * 4 + vi1_1]) - T.writes(T_softmax_maxelem_rf[i0_1, vi1_1]) + vk_1, v_i0, vk_0 = T.axis.remap("SSR", [k_1, i0, k_0]) + T.reads(A[v_i0, vk_0 * 4 + vk_1]) + T.writes(T_softmax_maxelem_rf[v_i0, vk_1]) with T.init(): - T_softmax_maxelem_rf[i0_1, vi1_1] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_rf[i0_1, vi1_1] = T.max(T_softmax_maxelem_rf[i0_1, vi1_1], A[i0_1, vi1_0 * 4 + vi1_1]) - for i0, i1_1 in T.grid(256, 4): + T_softmax_maxelem_rf[v_i0, vk_1] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_rf[v_i0, vk_1] = T.max(T_softmax_maxelem_rf[v_i0, vk_1], A[v_i0, vk_0 * 4 + vk_1]) + for i0, k_1 in T.grid(256, 4): with T.block("T_softmax_maxelem"): - vi1_1, i0_2 = T.axis.remap("RS", [i1_1, i0]) - T.reads(T_softmax_maxelem_rf[i0_2, vi1_1]) - T.writes(T_softmax_maxelem[i0_2]) + vk_1, v_i0 = T.axis.remap("RS", [k_1, i0]) + T.reads(T_softmax_maxelem_rf[v_i0, vk_1]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0_2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0_2] = T.max(T_softmax_maxelem[i0_2], T_softmax_maxelem_rf[i0_2, vi1_1]) - for i0_3, i1 in T.grid(256, 256): + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_1]) + for i0, i1 in T.grid(256, 256): for ax0, ax1 in T.grid(1, 256): with T.block("T_softmax_expsum"): - i0_4 = T.axis.spatial(256, i0_3 + ax0) - k = T.axis.reduce(256, ax1) - T.reads(A[i0_4, k], T_softmax_maxelem[i0_4]) - T.writes(T_softmax_expsum[i0_4]) + v_i0 = T.axis.spatial(256, i0 + ax0) + v_k = T.axis.reduce(256, ax1) + T.reads(A[v_i0, v_k], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_expsum[v_i0]) with T.init(): - T_softmax_expsum[i0_4] = T.float32(0) - T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T.exp(A[i0_4, k] - T_softmax_maxelem[i0_4], dtype="float32") + T_softmax_expsum[v_i0] = T.float32(0) + T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T.exp(A[v_i0, v_k] - 
T_softmax_maxelem[v_i0]) with T.block("T_softmax_norm"): - i0_5, i1_2 = T.axis.remap("SS", [i0_3, i1]) - T.reads(A[i0_5, i1_2], T_softmax_maxelem[i0_5], T_softmax_expsum[i0_5]) - T.writes(T_softmax_norm[i0_5, i1_2]) - T.block_attr({"axis":1}) - T_softmax_norm[i0_5, i1_2] = T.exp(A[i0_5, i1_2] - T_softmax_maxelem[i0_5], dtype="float32") / T_softmax_expsum[i0_5] + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_8(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") - T_softmax_expsum = T.alloc_buffer([256], dtype="float32") - for i0 in T.serial(256): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_exp = T.alloc_buffer((256, 256)) + T_softmax_expsum = T.alloc_buffer((256,)) + for i0 in range(256): for ax0, ax1 in T.grid(1, 256): with T.block("T_softmax_maxelem"): - i0_1 = T.axis.spatial(256, i0 + ax0) - k = T.axis.reduce(256, ax1) - T.reads(A[i0_1, k]) - T.writes(T_softmax_maxelem[i0_1]) + v_i0 = T.axis.spatial(256, i0 + ax0) + v_k = T.axis.reduce(256, ax1) + T.reads(A[v_i0, v_k]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) - for i1 in T.serial(256): + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k]) + for i1 in range(256): with T.block("T_softmax_exp"): - i0_2, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(A[i0_2, i1_1], T_softmax_maxelem[i0_2]) - T.writes(T_softmax_exp[i0_2, i1_1]) - T_softmax_exp[i0_2, i1_1] = T.exp(A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32") - for i0_3, i1 in T.grid(256, 256): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_exp[v_i0, v_i1]) + T_softmax_exp[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) + for i0, k in T.grid(256, 256): with T.block("T_softmax_expsum"): - i0_4, k = T.axis.remap("SR", [i0_3, i1]) - T.reads(T_softmax_exp[i0_4, k]) - T.writes(T_softmax_expsum[i0_4]) + v_i0, v_k = T.axis.remap("SR", [i0, k]) + T.reads(T_softmax_exp[v_i0, v_k]) + T.writes(T_softmax_expsum[v_i0]) with T.init(): - T_softmax_expsum[i0_4] = T.float32(0) - T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k] - for i0_5, i1 in T.grid(256, 256): + T_softmax_expsum[v_i0] = T.float32(0) + T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_exp[v_i0, v_k] + for i0, i1 in T.grid(256, 256): with T.block("T_softmax_norm"): - i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1]) - T.reads(T_softmax_exp[i0_6, i1_2], T_softmax_expsum[i0_6]) - T.writes(T_softmax_norm[i0_6, i1_2]) - T.block_attr({"axis":1}) - 
T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6] + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] # fmt: on decision_0 = [ ("SamplePerfectTile", [16, 16]), @@ -2206,127 +2128,121 @@ def test_cpu_cbr(): # fmt: off @T.prim_func def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) - Conv2dOutput = T.alloc_buffer([1, 112, 112, 64], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 2, 7, 1, 1, 2, 2, 32, 7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + Conv2dOutput = T.alloc_buffer((1, 112, 112, 64)) + for nn_0, yy_0, xx_0, ff_0, nn_1, yy_1, xx_1, ff_1, ry_0, rx_0, rc_0, nn_2, yy_2, xx_2, ff_2, ry_1, rx_1, rc_1, nn_3, yy_3, xx_3, ff_3 in T.grid(1, 2, 7, 1, 1, 2, 2, 32, 7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2): with T.block("Conv2dOutput"): - nn = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - yy = T.axis.spatial(112, i1_0 * 56 + i1_1 * 28 + i1_2 * 28 + i1_3) - xx = T.axis.spatial(112, i2_0 * 16 + i2_1 * 8 + i2_2 * 2 + i2_3) - ff = T.axis.spatial(64, i3_0 * 64 + i3_1 * 2 + i3_2 * 2 + i3_3) - ry = T.axis.reduce(7, i4_1 + i4_0) - rx = T.axis.reduce(7, i5_0 + i5_1) - rc = T.axis.reduce(3, i6_0 * 3 + i6_1) - T.reads(data[nn, yy * 2 + ry - 3, xx * 2 + rx - 3, rc], kernel[ry, rx, rc, ff]) - T.writes(Conv2dOutput[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_nn = T.axis.spatial(1, nn_0 + nn_1 + nn_2 + nn_3) + v_yy = T.axis.spatial(112, yy_0 * 56 + yy_1 * 28 + yy_2 * 28 + yy_3) + v_xx = T.axis.spatial(112, xx_0 * 16 + xx_1 * 8 + xx_2 * 2 + xx_3) + v_ff = T.axis.spatial(64, ff_0 * 64 + ff_1 * 2 + ff_2 * 2 + ff_3) + v_ry = T.axis.reduce(7, ry_0 + ry_1) + v_rx = T.axis.reduce(7, rx_0 + rx_1) + v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) + T.reads(data[v_nn, v_yy * 2 + v_ry - 3, v_xx * 2 + v_rx - 3, v_rc], kernel[v_ry, v_rx, v_rc, v_ff]) + T.writes(Conv2dOutput[v_nn, v_yy, v_xx, v_ff]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - Conv2dOutput[nn, yy, xx, ff] = T.float32(0) - Conv2dOutput[nn, yy, xx, ff] = Conv2dOutput[nn, yy, xx, ff] + T.if_then_else(3 <= yy * 2 + ry and yy * 2 + ry < 227 and 3 <= xx * 2 + rx and xx * 2 + rx < 227, data[nn, yy * 2 + ry - 3, xx * 2 + rx - 3, rc], T.float32(0), dtype="float32") * kernel[ry, rx, rc, ff] + Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = T.float32(0) + Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = Conv2dOutput[v_nn, v_yy, v_xx, v_ff] + T.if_then_else(3 <= v_yy * 2 + v_ry and v_yy * 2 + v_ry < 227 and 3 <= v_xx * 2 + v_rx and v_xx * 2 + v_rx < 227, data[v_nn, v_yy * 2 + v_ry - 3, v_xx * 2 + v_rx - 3, v_rc], T.float32(0)) * kernel[v_ry, v_rx, 
v_rc, v_ff] for i0, i1, i2, i3 in T.grid(1, 112, 112, 64): with T.block("compute"): - i0_4, i1_4, i2_4, i3_4 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(Conv2dOutput[i0_4, i1_4, i2_4, i3_4], bias[i3_4], bn_scale[i3_4], bn_offset[i3_4]) - T.writes(compute[i0_4, i1_4, i2_4, i3_4]) - compute[i0_4, i1_4, i2_4, i3_4] = T.max((Conv2dOutput[i0_4, i1_4, i2_4, i3_4] + bias[i3_4]) * bn_scale[i3_4] + bn_offset[i3_4], T.float32(0)) + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(Conv2dOutput[v_i0, v_i1, v_i2, v_i3], bias[v_i3], bn_scale[v_i3], bn_offset[v_i3]) + T.writes(compute[v_i0, v_i1, v_i2, v_i3]) + compute[v_i0, v_i1, v_i2, v_i3] = T.max((Conv2dOutput[v_i0, v_i1, v_i2, v_i3] + bias[v_i3]) * bn_scale[v_i3] + bn_offset[v_i3], T.float32(0)) @T.prim_func def cbr_1(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) - PaddedInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") - Conv2dOutput = T.alloc_buffer([1, 112, 112, 64], dtype="float32") - for i0_0, i1_0 in T.grid(1, 2): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + PaddedInput = T.alloc_buffer((1, 230, 230, 3)) + Conv2dOutput = T.alloc_buffer((1, 112, 112, 64)) + for nn_0, yy_0 in T.grid(1, 2): for ax0, ax1, ax2, ax3 in T.grid(1, 117, 229, 3): with T.block("PaddedInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(230, i1_0 * 112 + ax1) - i2 = T.axis.spatial(230, ax2) - i3 = T.axis.spatial(3, ax3) - T.reads(data[i0, i1 - 3, i2 - 3, i3]) - T.writes(PaddedInput[i0, i1, i2, i3]) - PaddedInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, data[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32") - for i2_0, i3_0, i0_1, i1_1, i2_1, i3_1 in T.grid(7, 1, 1, 2, 2, 32): - for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(230, yy_0 * 112 + ax1) + v_i2 = T.axis.spatial(230, ax2) + v_i3 = T.axis.spatial(3, ax3) + T.reads(data[v_i0, v_i1 - 3, v_i2 - 3, v_i3]) + T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3]) + PaddedInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, data[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) + for xx_0, ff_0, nn_1, yy_1, xx_1, ff_1 in T.grid(7, 1, 1, 2, 2, 32): + for ry_0, rx_0, rc_0, nn_2, yy_2, xx_2, ff_2, ry_1, rx_1, rc_1, nn_3, yy_3, xx_3, ff_3 in T.grid(7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2): with T.block("Conv2dOutput"): - nn = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - yy = T.axis.spatial(112, i1_0 * 56 + i1_1 * 28 + i1_2 * 28 + i1_3) - xx = T.axis.spatial(112, i2_0 * 16 + i2_1 * 8 + i2_2 * 2 + i2_3) - ff = T.axis.spatial(64, i3_0 * 64 + i3_1 * 2 + i3_2 * 2 + i3_3) - ry = T.axis.reduce(7, i4_1 + i4_0) - rx = T.axis.reduce(7, i5_0 + i5_1) - rc = T.axis.reduce(3, i6_0 * 3 + i6_1) - T.reads(PaddedInput[nn, yy * 2 + ry, xx * 2 + rx, rc], kernel[ry, rx, rc, ff]) - 
T.writes(Conv2dOutput[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_nn = T.axis.spatial(1, nn_0 + nn_1 + nn_2 + nn_3) + v_yy = T.axis.spatial(112, yy_0 * 56 + yy_1 * 28 + yy_2 * 28 + yy_3) + v_xx = T.axis.spatial(112, xx_0 * 16 + xx_1 * 8 + xx_2 * 2 + xx_3) + v_ff = T.axis.spatial(64, ff_0 * 64 + ff_1 * 2 + ff_2 * 2 + ff_3) + v_ry = T.axis.reduce(7, ry_0 + ry_1) + v_rx = T.axis.reduce(7, rx_0 + rx_1) + v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) + T.reads(PaddedInput[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc], kernel[v_ry, v_rx, v_rc, v_ff]) + T.writes(Conv2dOutput[v_nn, v_yy, v_xx, v_ff]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - Conv2dOutput[nn, yy, xx, ff] = T.float32(0) - Conv2dOutput[nn, yy, xx, ff] = Conv2dOutput[nn, yy, xx, ff] + PaddedInput[nn, yy * 2 + ry, xx * 2 + rx, rc] * kernel[ry, rx, rc, ff] + Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = T.float32(0) + Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = Conv2dOutput[v_nn, v_yy, v_xx, v_ff] + PaddedInput[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc] * kernel[v_ry, v_rx, v_rc, v_ff] for ax0, ax1, ax2, ax3 in T.grid(1, 28, 8, 2): with T.block("compute"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(112, i1_0 * 56 + i1_1 * 28 + ax1) - i2 = T.axis.spatial(112, i2_0 * 16 + i2_1 * 8 + ax2) - i3 = T.axis.spatial(64, i3_1 * 2 + ax3) - T.reads(Conv2dOutput[i0, i1, i2, i3], bias[i3], bn_scale[i3], bn_offset[i3]) - T.writes(compute[i0, i1, i2, i3]) - compute[i0, i1, i2, i3] = T.max((Conv2dOutput[i0, i1, i2, i3] + bias[i3]) * bn_scale[i3] + bn_offset[i3], T.float32(0)) + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(112, yy_0 * 56 + yy_1 * 28 + ax1) + v_i2 = T.axis.spatial(112, xx_0 * 16 + xx_1 * 8 + ax2) + v_i3 = T.axis.spatial(64, ff_1 * 2 + ax3) + T.reads(Conv2dOutput[v_i0, v_i1, v_i2, v_i3], bias[v_i3], bn_scale[v_i3], bn_offset[v_i3]) + T.writes(compute[v_i0, v_i1, v_i2, v_i3]) + compute[v_i0, v_i1, v_i2, v_i3] = T.max((Conv2dOutput[v_i0, v_i1, v_i2, v_i3] + bias[v_i3]) * bn_scale[v_i3] + bn_offset[v_i3], T.float32(0)) @T.prim_func def cbr_2(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) - PaddedInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") - Conv2dOutput = T.alloc_buffer([1, 112, 112, 64], dtype="float32") - for i0_0, i1_0 in T.grid(1, 2): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + PaddedInput = T.alloc_buffer((1, 230, 230, 3)) + Conv2dOutput = T.alloc_buffer((1, 112, 112, 64)) + for nn_0, yy_0 in T.grid(1, 2): for ax0, ax1, ax2, ax3 in T.grid(1, 117, 229, 3): with T.block("PaddedInput"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(230, i1_0 * 112 + ax1) - i2 = T.axis.spatial(230, ax2) - i3 = T.axis.spatial(3, ax3) - T.reads(data[i0, i1 - 3, i2 - 3, i3]) - T.writes(PaddedInput[i0, i1, i2, i3]) - PaddedInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, data[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32") - for i2_0, 
i3_0 in T.grid(7, 1): - for i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 2, 2, 32, 7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2): + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(230, yy_0 * 112 + ax1) + v_i2 = T.axis.spatial(230, ax2) + v_i3 = T.axis.spatial(3, ax3) + T.reads(data[v_i0, v_i1 - 3, v_i2 - 3, v_i3]) + T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3]) + PaddedInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, data[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) + for xx_0, ff_0 in T.grid(7, 1): + for nn_1, yy_1, xx_1, ff_1, ry_0, rx_0, rc_0, nn_2, yy_2, xx_2, ff_2, ry_1, rx_1, rc_1, nn_3, yy_3, xx_3, ff_3 in T.grid(1, 2, 2, 32, 7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2): with T.block("Conv2dOutput"): - nn = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2) - yy = T.axis.spatial(112, i1_0 * 56 + i1_1 * 28 + i1_2 * 28 + i1_3) - xx = T.axis.spatial(112, i2_0 * 16 + i2_1 * 8 + i2_2 * 2 + i2_3) - ff = T.axis.spatial(64, i3_0 * 64 + i3_1 * 2 + i3_2 * 2 + i3_3) - ry = T.axis.reduce(7, i4_1 + i4_0) - rx = T.axis.reduce(7, i5_0 + i5_1) - rc = T.axis.reduce(3, i6_0 * 3 + i6_1) - T.reads(PaddedInput[nn, yy * 2 + ry, xx * 2 + rx, rc], kernel[ry, rx, rc, ff]) - T.writes(Conv2dOutput[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_nn = T.axis.spatial(1, nn_0 + nn_1 + nn_2 + nn_3) + v_yy = T.axis.spatial(112, yy_0 * 56 + yy_1 * 28 + yy_2 * 28 + yy_3) + v_xx = T.axis.spatial(112, xx_0 * 16 + xx_1 * 8 + xx_2 * 2 + xx_3) + v_ff = T.axis.spatial(64, ff_0 * 64 + ff_1 * 2 + ff_2 * 2 + ff_3) + v_ry = T.axis.reduce(7, ry_0 + ry_1) + v_rx = T.axis.reduce(7, rx_0 + rx_1) + v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) + T.reads(PaddedInput[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc], kernel[v_ry, v_rx, v_rc, v_ff]) + T.writes(Conv2dOutput[v_nn, v_yy, v_xx, v_ff]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - Conv2dOutput[nn, yy, xx, ff] = T.float32(0) - Conv2dOutput[nn, yy, xx, ff] = Conv2dOutput[nn, yy, xx, ff] + PaddedInput[nn, yy * 2 + ry, xx * 2 + rx, rc] * kernel[ry, rx, rc, ff] + Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = T.float32(0) + Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = Conv2dOutput[v_nn, v_yy, v_xx, v_ff] + PaddedInput[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc] * kernel[v_ry, v_rx, v_rc, v_ff] for ax0, ax1, ax2, ax3 in T.grid(1, 56, 16, 64): with T.block("compute"): - i0 = T.axis.spatial(1, ax0) - i1 = T.axis.spatial(112, i1_0 * 56 + ax1) - i2 = T.axis.spatial(112, i2_0 * 16 + ax2) - i3 = T.axis.spatial(64, ax3) - T.reads(Conv2dOutput[i0, i1, i2, i3], bias[i3], bn_scale[i3], bn_offset[i3]) - T.writes(compute[i0, i1, i2, i3]) - compute[i0, i1, i2, i3] = T.max((Conv2dOutput[i0, i1, i2, i3] + bias[i3]) * bn_scale[i3] + bn_offset[i3], T.float32(0)) + v_i0 = T.axis.spatial(1, ax0) + v_i1 = T.axis.spatial(112, yy_0 * 56 + ax1) + v_i2 = T.axis.spatial(112, xx_0 * 16 + ax2) + v_i3 = T.axis.spatial(64, ax3) + T.reads(Conv2dOutput[v_i0, v_i1, v_i2, v_i3], bias[v_i3], bn_scale[v_i3], bn_offset[v_i3]) + T.writes(compute[v_i0, v_i1, v_i2, v_i3]) + compute[v_i0, v_i1, v_i2, v_i3] = T.max((Conv2dOutput[v_i0, v_i1, v_i2, v_i3] + bias[v_i3]) * bn_scale[v_i3] + bn_offset[v_i3], T.float32(0)) # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 1, 1, 1]), @@ -2375,140 +2291,134 @@ def test_cpu_tbg(): # fmt: off @T.prim_func def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: 
T.Buffer((1, 12, 128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) - query_T = T.alloc_buffer([1, 12, 128, 64], dtype="float32") - value_T = T.alloc_buffer([1, 12, 64, 128], dtype="float32") - C_global = T.alloc_buffer([1, 12, 128, 128], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1 in T.grid(1, 1, 1, 2, 1, 6, 2): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + query_T = T.alloc_buffer((1, 12, 128, 64)) + value_T = T.alloc_buffer((1, 12, 64, 128)) + C_global = T.alloc_buffer((1, 12, 128, 128)) + for b_0, h_0, i_0, j_0, b_1, h_1, i_1 in T.grid(1, 1, 1, 2, 1, 6, 2): for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 64): with T.block("value_T"): - b = T.axis.spatial(1, ax0) - h = T.axis.spatial(12, i1_1 * 2 + ax1) - d = T.axis.spatial(64, ax2) - l = T.axis.spatial(128, i3_0 * 64 + ax3) - T.reads(value[b, l, h, d]) - T.writes(value_T[b, h, d, l]) - value_T[b, h, d, l] = value[b, l, h, d] + v_b = T.axis.spatial(1, ax0) + v_h = T.axis.spatial(12, h_1 * 2 + ax1) + v_d = T.axis.spatial(64, ax2) + v_l = T.axis.spatial(128, j_0 * 64 + ax3) + T.reads(value[v_b, v_l, v_h, v_d]) + T.writes(value_T[v_b, v_h, v_d, v_l]) + value_T[v_b, v_h, v_d, v_l] = value[v_b, v_l, v_h, v_d] for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 64): with T.block("query_T"): - b = T.axis.spatial(1, ax0) - h = T.axis.spatial(12, i1_1 * 2 + ax1) - l = T.axis.spatial(128, i2_1 * 64 + ax2) - d = T.axis.spatial(64, ax3) - T.reads(query[b, l, h, d]) - T.writes(query_T[b, h, l, d]) - query_T[b, h, l, d] = query[b, l, h, d] - for i3_1 in T.serial(8): - for i4_0, i0_2, i1_2, i2_2, i3_2, i4_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 2, 2, 4, 64, 1, 1, 32, 2): + v_b = T.axis.spatial(1, ax0) + v_h = T.axis.spatial(12, h_1 * 2 + ax1) + v_l = T.axis.spatial(128, i_1 * 64 + ax2) + v_d = T.axis.spatial(64, ax3) + T.reads(query[v_b, v_l, v_h, v_d]) + T.writes(query_T[v_b, v_h, v_l, v_d]) + query_T[v_b, v_h, v_l, v_d] = query[v_b, v_l, v_h, v_d] + for j_1 in range(8): + for k_0, b_2, h_2, i_2, j_2, k_1, b_3, h_3, i_3, j_3 in T.grid(1, 1, 2, 2, 4, 64, 1, 1, 32, 2): with T.block("C"): - b = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0) - h = T.axis.spatial(12, i1_0 * 12 + i1_1 * 2 + i1_2 + i1_3) - i = T.axis.spatial(128, i2_0 * 128 + i2_1 * 64 + i2_2 * 32 + i2_3) - j = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + i3_3) - k = T.axis.reduce(64, i4_0 * 64 + i4_1) - T.reads(query_T[b, h, i, k], value_T[b, h, k, j]) - T.writes(C_global[b, h, i, j]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) + v_h = T.axis.spatial(12, h_0 * 12 + h_1 * 2 + h_2 + h_3) + v_i = T.axis.spatial(128, i_0 * 128 + i_1 * 64 + i_2 * 32 + i_3) + v_j = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + j_2 * 2 + j_3) + v_k = T.axis.reduce(64, k_0 * 64 + k_1) + T.reads(query_T[v_b, v_h, v_i, v_k], value_T[v_b, v_h, v_k, v_j]) + T.writes(C_global[v_b, v_h, v_i, v_j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - C_global[b, h, i, j] = T.float32(0) - C_global[b, h, i, j] = C_global[b, h, i, j] + query_T[b, h, i, k] * value_T[b, h, k, j] + C_global[v_b, v_h, v_i, v_j] = T.float32(0) + C_global[v_b, v_h, v_i, v_j] = 
C_global[v_b, v_h, v_i, v_j] + query_T[v_b, v_h, v_i, v_k] * value_T[v_b, v_h, v_k, v_j] for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 8): with T.block("C_global"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(12, i1_1 * 2 + ax1) - v2 = T.axis.spatial(128, i2_1 * 64 + ax2) - v3 = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + ax3) + v1 = T.axis.spatial(12, h_1 * 2 + ax1) + v2 = T.axis.spatial(128, i_1 * 64 + ax2) + v3 = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + ax3) T.reads(C_global[v0, v1, v2, v3]) T.writes(C[v0, v1, v2, v3]) C[v0, v1, v2, v3] = C_global[v0, v1, v2, v3] @T.prim_func def tbg_1(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: T.Buffer((1, 12, 128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) - query_T = T.alloc_buffer([1, 12, 128, 64], dtype="float32") - value_T = T.alloc_buffer([1, 12, 64, 128], dtype="float32") - C_global = T.alloc_buffer([1, 12, 128, 128], dtype="float32") - for i0, i1, i2, i3 in T.grid(1, 12, 128, 64): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + query_T = T.alloc_buffer((1, 12, 128, 64)) + value_T = T.alloc_buffer((1, 12, 64, 128)) + C_global = T.alloc_buffer((1, 12, 128, 128)) + for b, h, l, d in T.grid(1, 12, 128, 64): with T.block("query_T"): - b, h, l, d = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(query[b, l, h, d]) - T.writes(query_T[b, h, l, d]) - query_T[b, h, l, d] = query[b, l, h, d] - for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 1, 2): - for i0_1, i1_1, i2_1, i3_1, i4_0, i0_2, i1_2, i2_2, i3_2, i4_1 in T.grid(1, 6, 2, 8, 1, 1, 2, 2, 4, 64): + v_b, v_h, v_l, v_d = T.axis.remap("SSSS", [b, h, l, d]) + T.reads(query[v_b, v_l, v_h, v_d]) + T.writes(query_T[v_b, v_h, v_l, v_d]) + query_T[v_b, v_h, v_l, v_d] = query[v_b, v_l, v_h, v_d] + for b_0, h_0, i_0, j_0 in T.grid(1, 1, 1, 2): + for b_1, h_1, i_1, j_1, k_0, b_2, h_2, i_2, j_2, k_1 in T.grid(1, 6, 2, 8, 1, 1, 2, 2, 4, 64): for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 2): with T.block("value_T"): - b = T.axis.spatial(1, ax0) - h = T.axis.spatial(12, i1_1 * 2 + i1_2 + ax1) - d = T.axis.spatial(64, i4_1 + ax2) - l = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + ax3) - T.reads(value[b, l, h, d]) - T.writes(value_T[b, h, d, l]) - value_T[b, h, d, l] = value[b, l, h, d] - for i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 32, 2): + v_b = T.axis.spatial(1, ax0) + v_h = T.axis.spatial(12, h_1 * 2 + h_2 + ax1) + v_d = T.axis.spatial(64, k_1 + ax2) + v_l = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + j_2 * 2 + ax3) + T.reads(value[v_b, v_l, v_h, v_d]) + T.writes(value_T[v_b, v_h, v_d, v_l]) + value_T[v_b, v_h, v_d, v_l] = value[v_b, v_l, v_h, v_d] + for b_3, h_3, i_3, j_3 in T.grid(1, 1, 32, 2): with T.block("C"): - b = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0) - h = T.axis.spatial(12, i1_0 * 12 + i1_1 * 2 + i1_2 + i1_3) - i = T.axis.spatial(128, i2_0 * 128 + i2_1 * 64 + i2_2 * 32 + i2_3) - j = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + i3_3) - k = T.axis.reduce(64, i4_0 * 64 + i4_1) - T.reads(query_T[b, h, i, k], value_T[b, h, k, j]) - T.writes(C_global[b, h, i, j]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) + v_h = 
T.axis.spatial(12, h_0 * 12 + h_1 * 2 + h_2 + h_3) + v_i = T.axis.spatial(128, i_0 * 128 + i_1 * 64 + i_2 * 32 + i_3) + v_j = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + j_2 * 2 + j_3) + v_k = T.axis.reduce(64, k_0 * 64 + k_1) + T.reads(query_T[v_b, v_h, v_i, v_k], value_T[v_b, v_h, v_k, v_j]) + T.writes(C_global[v_b, v_h, v_i, v_j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - C_global[b, h, i, j] = T.float32(0) - C_global[b, h, i, j] = C_global[b, h, i, j] + query_T[b, h, i, k] * value_T[b, h, k, j] + C_global[v_b, v_h, v_i, v_j] = T.float32(0) + C_global[v_b, v_h, v_i, v_j] = C_global[v_b, v_h, v_i, v_j] + query_T[v_b, v_h, v_i, v_k] * value_T[v_b, v_h, v_k, v_j] for ax0, ax1, ax2, ax3 in T.grid(1, 12, 128, 64): with T.block("C_global"): v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - v3 = T.axis.spatial(128, i3_0 * 64 + ax3) + v3 = T.axis.spatial(128, j_0 * 64 + ax3) T.reads(C_global[v0, v1, v2, v3]) T.writes(C[v0, v1, v2, v3]) C[v0, v1, v2, v3] = C_global[v0, v1, v2, v3] @T.prim_func def tbg_2(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: T.Buffer((1, 12, 128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) - value_T = T.alloc_buffer([1, 12, 64, 128], dtype="float32") - for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1 in T.grid(1, 1, 1, 2, 1, 6, 2, 8): + T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + value_T = T.alloc_buffer((1, 12, 64, 128)) + for b_0, h_0, i_0, j_0, b_1, h_1, i_1, j_1 in T.grid(1, 1, 1, 2, 1, 6, 2, 8): for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 8): with T.block("value_T"): - b = T.axis.spatial(1, ax0) - h = T.axis.spatial(12, i1_1 * 2 + ax1) - d = T.axis.spatial(64, ax2) - l = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + ax3) - T.reads(value[b, l, h, d]) - T.writes(value_T[b, h, d, l]) - value_T[b, h, d, l] = value[b, l, h, d] - for i4_0, i0_2, i1_2, i2_2, i3_2, i4_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 2, 2, 4, 64, 1, 1, 32, 2): + v_b = T.axis.spatial(1, ax0) + v_h = T.axis.spatial(12, h_1 * 2 + ax1) + v_d = T.axis.spatial(64, ax2) + v_l = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + ax3) + T.reads(value[v_b, v_l, v_h, v_d]) + T.writes(value_T[v_b, v_h, v_d, v_l]) + value_T[v_b, v_h, v_d, v_l] = value[v_b, v_l, v_h, v_d] + for k_0, b_2, h_2, i_2, j_2, k_1, b_3, h_3, i_3, j_3 in T.grid(1, 1, 2, 2, 4, 64, 1, 1, 32, 2): with T.block("C"): - b = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0) - h = T.axis.spatial(12, i1_0 * 12 + i1_1 * 2 + i1_2 + i1_3) - i = T.axis.spatial(128, i2_0 * 128 + i2_1 * 64 + i2_2 * 32 + i2_3) - j = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + i3_3) - k = T.axis.reduce(64, i4_0 * 64 + i4_1) - T.reads(query[b, i, h, k], value_T[b, h, k, j]) - T.writes(C[b, h, i, j]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) + v_h = T.axis.spatial(12, h_0 * 12 + h_1 * 2 + h_2 + h_3) + v_i = T.axis.spatial(128, i_0 * 128 + i_1 * 64 + i_2 * 32 + i_3) + v_j = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + j_2 * 2 + j_3) + v_k = T.axis.reduce(64, k_0 * 64 + k_1) + T.reads(query[v_b, v_i, v_h, v_k], value_T[v_b, v_h, v_k, v_j]) + T.writes(C[v_b, v_h, v_i, 
v_j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): - C[b, h, i, j] = T.float32(0) - C[b, h, i, j] = C[b, h, i, j] + query[b, i, h, k] * value_T[b, h, k, j] + C[v_b, v_h, v_i, v_j] = T.float32(0) + C[v_b, v_h, v_i, v_j] = C[v_b, v_h, v_i, v_j] + query[v_b, v_i, v_h, v_k] * value_T[v_b, v_h, v_k, v_j] # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 1, 1, 1]), diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index 1e7ba24fa8b1..23746ba0dc6d 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -43,56 +43,54 @@ def test_cuda_c1d(): # fmt: off @T.prim_func def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":16}) - conv1d_nlc_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") - PadInput_shared = T.alloc_buffer([1, 258, 64], dtype="float32", scope="shared") - weight_shared = T.alloc_buffer([3, 64, 128], dtype="float32", scope="shared") - for i0_0_i1_0_i2_0_fused in T.thread_binding(4, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_fused in T.thread_binding(16, thread="vthread.x"): - for i0_2_i1_2_i2_2_fused in T.thread_binding(4, thread="threadIdx.x"): - for i3_0, i4_0 in T.grid(1, 16): - for ax0_ax1_ax2_fused in T.serial(260): + T.block_attr({"meta_schedule.unroll_explicit": 16}) + conv1d_nlc_local = T.alloc_buffer((1, 128, 128), scope="local") + PadInput_shared = T.alloc_buffer((1, 258, 64), scope="shared") + weight_shared = T.alloc_buffer((3, 64, 128), scope="shared") + for n_0_l_0_co_0_fused in T.thread_binding(4, thread="blockIdx.x"): + for n_1_l_1_co_1_fused in T.thread_binding(16, thread="vthread.x"): + for n_2_l_2_co_2_fused in T.thread_binding(4, thread="threadIdx.x"): + for rl_0, rc_0 in T.grid(1, 16): + for ax0_ax1_ax2_fused in range(260): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused // 4) - v2 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 4) + v1 = T.axis.spatial(258, n_0_l_0_co_0_fused * 64 + ax0_ax1_ax2_fused // 4) + v2 = T.axis.spatial(64, rc_0 * 4 + ax0_ax1_ax2_fused % 4) T.reads(inputs[v0, v1 - 1, v2]) T.writes(PadInput_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) - PadInput_shared[v0, v1, v2] = T.if_then_else(1 <= v1 and v1 < 257, inputs[v0, v1 - 1, v2], T.float32(0), dtype="float32") - for ax0_ax1_ax2_fused in T.serial(1536): + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + PadInput_shared[v0, v1, v2] = T.if_then_else(1 <= v1 and v1 < 257, inputs[v0, v1 - 1, v2], T.float32(0)) + for ax0_ax1_ax2_fused in range(1536): with T.block("weight_shared"): v0 = T.axis.spatial(3, ax0_ax1_ax2_fused // 512) - v1 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 512 // 128) + v1 = T.axis.spatial(64, rc_0 * 4 + ax0_ax1_ax2_fused % 512 // 128) v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128) T.reads(weight[v0, v1, v2]) T.writes(weight_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) weight_shared[v0, v1, v2] = weight[v0, v1, v2] - for 
i3_1, i4_1, i0_3, i1_3, i2_3, i3_2, i4_2, i0_4, i1_4, i2_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8): + for rl_1, rc_1, n_3, l_3, co_3, rl_2, rc_2, n_4, l_4, co_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8): with T.block("conv1d_nlc"): - n = T.axis.spatial(1, i0_4 + i0_3) - l = T.axis.spatial(128, i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + i1_3 * 4 + i1_4) - co = T.axis.spatial(128, i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + i2_3 * 8 + i2_4) - rl = T.axis.reduce(3, i3_0 * 3 + i3_1 * 3 + i3_2) - rc = T.axis.reduce(64, i4_0 * 4 + i4_1 * 2 + i4_2) - T.reads(PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc], weight_shared[rl, rc, co]) - T.writes(conv1d_nlc_local[n, l, co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_n = T.axis.spatial(1, n_3 + n_4) + v_l = T.axis.spatial(128, n_0_l_0_co_0_fused * 32 + n_1_l_1_co_1_fused // 2 * 4 + l_3 * 4 + l_4) + v_co = T.axis.spatial(128, n_1_l_1_co_1_fused % 2 * 64 + n_2_l_2_co_2_fused * 16 + co_3 * 8 + co_4) + v_rl = T.axis.reduce(3, rl_0 * 3 + rl_1 * 3 + rl_2) + v_rc = T.axis.reduce(64, rc_0 * 4 + rc_1 * 2 + rc_2) + T.reads(PadInput_shared[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc], weight_shared[v_rl, v_rc, v_co]) + T.writes(conv1d_nlc_local[v_n, v_l, v_co]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - conv1d_nlc_local[n, l, co] = T.float32(0) - conv1d_nlc_local[n, l, co] = conv1d_nlc_local[n, l, co] + PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc] * weight_shared[rl, rc, co] + conv1d_nlc_local[v_n, v_l, v_co] = T.float32(0) + conv1d_nlc_local[v_n, v_l, v_co] = conv1d_nlc_local[v_n, v_l, v_co] + PadInput_shared[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc] * weight_shared[v_rl, v_rc, v_co] for ax0, ax1, ax2 in T.grid(1, 4, 16): with T.block("conv1d_nlc_local"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + ax1) - v2 = T.axis.spatial(128, i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + ax2) + v1 = T.axis.spatial(128, n_0_l_0_co_0_fused * 32 + n_1_l_1_co_1_fused // 2 * 4 + ax1) + v2 = T.axis.spatial(128, n_1_l_1_co_1_fused % 2 * 64 + n_2_l_2_co_2_fused * 16 + ax2) T.reads(conv1d_nlc_local[v0, v1, v2]) T.writes(conv1d_nlc[v0, v1, v2]) conv1d_nlc[v0, v1, v2] = conv1d_nlc_local[v0, v1, v2] @@ -123,59 +121,59 @@ def test_cuda_c2d(): # fmt: off @T.prim_func def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":16}) - conv2d_nhwc_local = T.alloc_buffer([1, 112, 112, 64], dtype="float32", scope="local") - PadInput_shared = T.alloc_buffer([1, 230, 230, 3], dtype="float32", scope="shared") - weight_shared = T.alloc_buffer([7, 7, 3, 64], dtype="float32", scope="shared") - for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(16, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(56, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(14, thread="threadIdx.x"): - for i4_0, i5_0, i6_0 in T.grid(1, 1, 1): - for ax0_ax1_ax2_ax3_fused in 
T.serial(80379): + T.block_attr({"meta_schedule.unroll_explicit": 16}) + conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 64), scope="local") + PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared") + weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared") + for n_0_h_0_w_0_co_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for n_1_h_1_w_1_co_1_fused in T.thread_binding(56, thread="vthread.x"): + for n_2_h_2_w_2_co_2_fused in T.thread_binding(14, thread="threadIdx.x"): + for rh_0, rw_0, rc_0 in T.grid(1, 1, 1): + for ax0_ax1_ax2_ax3_fused in range(80379): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(230, ax0_ax1_ax2_ax3_fused // 351) - v2 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_fused // 8 * 112 + ax0_ax1_ax2_ax3_fused % 351 // 3) + v2 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused // 8 * 112 + ax0_ax1_ax2_ax3_fused % 351 // 3) v3 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 3) T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) - PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0), dtype="float32") - for ax0_ax1_ax2_ax3_fused in T.serial(1176): + T.block_attr({"meta_schedule.cooperative_fetch": 2}) + PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused in range(1176): with T.block("weight_shared"): v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 168) v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused % 168 // 24) v2 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 24 // 8) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 8 * 8 + ax0_ax1_ax2_ax3_fused % 8) + v3 = T.axis.spatial(64, n_0_h_0_w_0_co_0_fused % 8 * 8 + ax0_ax1_ax2_ax3_fused % 8) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] - for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 7, 1, 1, 8, 4, 1, 7, 1, 3, 1, 1, 1, 2): + for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(1, 7, 1, 1, 8, 4, 1, 7, 1, 3, 1, 1, 1, 2): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_4) - h = T.axis.spatial(112, i1_4 + i0_2_i1_2_i2_2_i3_2_fused * 8 + i1_3) - w = T.axis.spatial(112, i0_0_i1_0_i2_0_i3_0_fused // 8 * 56 + i0_1_i1_1_i2_1_i3_1_fused // 4 * 4 + i2_3 + i2_4) - co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 8 * 8 + i0_1_i1_1_i2_1_i3_1_fused % 4 * 2 + i3_3 * 2 + i3_4) - rh = T.axis.reduce(7, i4_0 * 7 + i4_1 * 7 + i4_2) - rw = T.axis.reduce(7, i5_2 + i5_0 * 7 + i5_1) - rc = T.axis.reduce(3, i6_0 * 3 + i6_1 * 3 + i6_2) - T.reads(PadInput_shared[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight_shared[rh, rw, rc, co]) - T.writes(conv2d_nhwc_local[n, h, w, co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_n = T.axis.spatial(1, n_3 + n_4) + v_h = T.axis.spatial(112, n_2_h_2_w_2_co_2_fused * 8 + h_3 + h_4) + v_w = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 56 + n_1_h_1_w_1_co_1_fused // 4 * 4 + w_3 + w_4) + v_co = T.axis.spatial(64, n_0_h_0_w_0_co_0_fused % 8 * 8 + n_1_h_1_w_1_co_1_fused % 4 * 2 
+ co_3 * 2 + co_4) + v_rh = T.axis.reduce(7, rh_0 * 7 + rh_1 * 7 + rh_2) + v_rw = T.axis.reduce(7, rw_0 * 7 + rw_1 + rw_2) + v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1 * 3 + rc_2) + T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - conv2d_nhwc_local[n, h, w, co] = T.float32(0) - conv2d_nhwc_local[n, h, w, co] = conv2d_nhwc_local[n, h, w, co] + PadInput_shared[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight_shared[rh, rw, rc, co] + conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 8, 4, 2): with T.block("conv2d_nhwc_local"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(112, i0_2_i1_2_i2_2_i3_2_fused * 8 + ax1) - v2 = T.axis.spatial(112, i0_0_i1_0_i2_0_i3_0_fused // 8 * 56 + i0_1_i1_1_i2_1_i3_1_fused // 4 * 4 + ax2) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 8 * 8 + i0_1_i1_1_i2_1_i3_1_fused % 4 * 2 + ax3) + v1 = T.axis.spatial(112, n_2_h_2_w_2_co_2_fused * 8 + ax1) + v2 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 56 + n_1_h_1_w_1_co_1_fused // 4 * 4 + ax2) + v3 = T.axis.spatial(64, n_0_h_0_w_0_co_0_fused % 8 * 8 + n_1_h_1_w_1_co_1_fused % 4 * 2 + ax3) T.reads(conv2d_nhwc_local[v0, v1, v2, v3]) T.writes(conv2d_nhwc[v0, v1, v2, v3]) conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3] @@ -207,30 +205,30 @@ def test_cuda_c3d(): # fmt: off @T.prim_func def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":16}) - conv3d_ndhwc_local = T.alloc_buffer([1, 8, 112, 112, 64], dtype="float32", scope="local") - PadInput_shared = T.alloc_buffer([1, 22, 230, 230, 3], dtype="float32", scope="shared") - weight_shared = T.alloc_buffer([7, 7, 7, 3, 64], dtype="float32", scope="shared") - for i0_0_i1_0_i2_0_i3_0_i4_0_fused in T.thread_binding(2, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_i4_1_fused in T.thread_binding(8, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_i4_2_fused in T.thread_binding(392, thread="threadIdx.x"): - for i5_0, i6_0, i7_0, i8_0 in T.grid(1, 1, 1, 1): - for ax0_ax1_ax2_ax3_ax4_fused in T.serial(1687959): + T.block_attr({"meta_schedule.unroll_explicit": 16}) + conv3d_ndhwc_local = T.alloc_buffer((1, 8, 112, 112, 64), scope="local") + PadInput_shared = T.alloc_buffer((1, 22, 230, 230, 3), scope="shared") + weight_shared = T.alloc_buffer((7, 7, 7, 3, 64), scope="shared") + for n_0_d_0_h_0_w_0_co_0_fused in T.thread_binding(2, thread="blockIdx.x"): + for n_1_d_1_h_1_w_1_co_1_fused in T.thread_binding(8, thread="vthread.x"): + for n_2_d_2_h_2_w_2_co_2_fused in T.thread_binding(392, thread="threadIdx.x"): + for rd_0, rh_0, rw_0, rc_0 in T.grid(1, 1, 1, 1): + for ax0_ax1_ax2_ax3_ax4_fused in range(1687959): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = 
T.axis.spatial(22, ax0_ax1_ax2_ax3_ax4_fused // 80379) v2 = T.axis.spatial(230, ax0_ax1_ax2_ax3_ax4_fused % 80379 // 351) - v3 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_i4_0_fused * 112 + ax0_ax1_ax2_ax3_ax4_fused % 351 // 3) + v3 = T.axis.spatial(230, n_0_d_0_h_0_w_0_co_0_fused * 112 + ax0_ax1_ax2_ax3_ax4_fused % 351 // 3) v4 = T.axis.spatial(3, ax0_ax1_ax2_ax3_ax4_fused % 3) T.reads(inputs[v0, v1 - 3, v2 - 3, v3 - 3, v4]) T.writes(PadInput_shared[v0, v1, v2, v3, v4]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) - PadInput_shared[v0, v1, v2, v3, v4] = T.if_then_else(3 <= v1 and v1 < 19 and 3 <= v2 and v2 < 227 and 3 <= v3 and v3 < 227, inputs[v0, v1 - 3, v2 - 3, v3 - 3, v4], T.float32(0), dtype="float32") - for ax0_ax1_ax2_ax3_ax4_fused in T.serial(65856): + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + PadInput_shared[v0, v1, v2, v3, v4] = T.if_then_else(3 <= v1 and v1 < 19 and 3 <= v2 and v2 < 227 and 3 <= v3 and v3 < 227, inputs[v0, v1 - 3, v2 - 3, v3 - 3, v4], T.float32(0)) + for ax0_ax1_ax2_ax3_ax4_fused in range(65856): with T.block("weight_shared"): v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_ax4_fused // 9408) v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_ax4_fused % 9408 // 1344) @@ -239,32 +237,32 @@ def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 v4 = T.axis.spatial(64, ax0_ax1_ax2_ax3_ax4_fused % 64) T.reads(weight[v0, v1, v2, v3, v4]) T.writes(weight_shared[v0, v1, v2, v3, v4]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) weight_shared[v0, v1, v2, v3, v4] = weight[v0, v1, v2, v3, v4] - for i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_2, i6_2, i7_2, i8_2, i0_4, i1_4, i2_4, i3_4, i4_4 in T.grid(7, 7, 1, 3, 1, 2, 2, 1, 32, 1, 1, 7, 1, 1, 1, 2, 4, 1): + for rd_1, rh_1, rw_1, rc_1, n_3, d_3, h_3, w_3, co_3, rd_2, rh_2, rw_2, rc_2, n_4, d_4, h_4, w_4, co_4 in T.grid(7, 7, 1, 3, 1, 2, 2, 1, 32, 1, 1, 7, 1, 1, 1, 2, 4, 1): with T.block("conv3d_ndhwc"): - n = T.axis.spatial(1, i0_4 + i0_3) - d = T.axis.spatial(8, i1_4 + i0_2_i1_2_i2_2_i3_2_i4_2_fused // 98 * 2 + i1_3) - h = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_i4_1_fused // 2 * 28 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 98 // 14 * 4 + i2_3 * 2 + i2_4) - w = T.axis.spatial(112, i0_0_i1_0_i2_0_i3_0_i4_0_fused * 56 + i0_1_i1_1_i2_1_i3_1_i4_1_fused % 2 * 28 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 14 // 2 * 4 + i3_3 * 4 + i3_4) - co = T.axis.spatial(64, i0_2_i1_2_i2_2_i3_2_i4_2_fused % 2 * 32 + i4_3 + i4_4) - rd = T.axis.reduce(7, i5_2 + i5_0 * 7 + i5_1) - rh = T.axis.reduce(7, i6_0 * 7 + i6_1 + i6_2) - rw = T.axis.reduce(7, i7_0 * 7 + i7_1 * 7 + i7_2) - rc = T.axis.reduce(3, i8_0 * 3 + i8_1 + i8_2) - T.reads(PadInput_shared[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight_shared[rd, rh, rw, rc, co]) - T.writes(conv3d_ndhwc_local[n, d, h, w, co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_n = T.axis.spatial(1, n_3 + n_4) + v_d = T.axis.spatial(8, n_2_d_2_h_2_w_2_co_2_fused // 98 * 2 + d_3 + d_4) + v_h = T.axis.spatial(112, n_1_d_1_h_1_w_1_co_1_fused // 2 * 28 + n_2_d_2_h_2_w_2_co_2_fused % 98 // 14 * 4 + h_3 * 2 + h_4) + v_w = T.axis.spatial(112, n_0_d_0_h_0_w_0_co_0_fused * 56 + n_1_d_1_h_1_w_1_co_1_fused % 2 * 28 + n_2_d_2_h_2_w_2_co_2_fused % 14 // 2 * 4 + w_3 * 4 + w_4) + v_co = T.axis.spatial(64, n_2_d_2_h_2_w_2_co_2_fused % 2 * 32 + co_3 + co_4) + v_rd = T.axis.reduce(7, rd_0 * 7 + rd_1 + 
rd_2) + v_rh = T.axis.reduce(7, rh_0 * 7 + rh_1 + rh_2) + v_rw = T.axis.reduce(7, rw_0 * 7 + rw_1 * 7 + rw_2) + v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1 + rc_2) + T.reads(PadInput_shared[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rd, v_rh, v_rw, v_rc, v_co]) + T.writes(conv3d_ndhwc_local[v_n, v_d, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - conv3d_ndhwc_local[n, d, h, w, co] = T.float32(0) - conv3d_ndhwc_local[n, d, h, w, co] = conv3d_ndhwc_local[n, d, h, w, co] + PadInput_shared[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight_shared[rd, rh, rw, rc, co] + conv3d_ndhwc_local[v_n, v_d, v_h, v_w, v_co] = T.float32(0) + conv3d_ndhwc_local[v_n, v_d, v_h, v_w, v_co] = conv3d_ndhwc_local[v_n, v_d, v_h, v_w, v_co] + PadInput_shared[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rd, v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 2, 4, 4, 32): with T.block("conv3d_ndhwc_local"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(8, i0_2_i1_2_i2_2_i3_2_i4_2_fused // 98 * 2 + ax1) - v2 = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_i4_1_fused // 2 * 28 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 98 // 14 * 4 + ax2) - v3 = T.axis.spatial(112, i0_0_i1_0_i2_0_i3_0_i4_0_fused * 56 + i0_1_i1_1_i2_1_i3_1_i4_1_fused % 2 * 28 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 14 // 2 * 4 + ax3) - v4 = T.axis.spatial(64, i0_2_i1_2_i2_2_i3_2_i4_2_fused % 2 * 32 + ax4) + v1 = T.axis.spatial(8, n_2_d_2_h_2_w_2_co_2_fused // 98 * 2 + ax1) + v2 = T.axis.spatial(112, n_1_d_1_h_1_w_1_co_1_fused // 2 * 28 + n_2_d_2_h_2_w_2_co_2_fused % 98 // 14 * 4 + ax2) + v3 = T.axis.spatial(112, n_0_d_0_h_0_w_0_co_0_fused * 56 + n_1_d_1_h_1_w_1_co_1_fused % 2 * 28 + n_2_d_2_h_2_w_2_co_2_fused % 14 // 2 * 4 + ax3) + v4 = T.axis.spatial(64, n_2_d_2_h_2_w_2_co_2_fused % 2 * 32 + ax4) T.reads(conv3d_ndhwc_local[v0, v1, v2, v3, v4]) T.writes(conv3d_ndhwc[v0, v1, v2, v3, v4]) conv3d_ndhwc[v0, v1, v2, v3, v4] = conv3d_ndhwc_local[v0, v1, v2, v3, v4] @@ -297,69 +295,67 @@ def test_cuda_cap(): # fmt: off @T.prim_func def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":64}) - conv2d_capsule_nhwijc_local = T.alloc_buffer([1, 8, 8, 4, 4, 32], dtype="float32", scope="local") - PadInput_shared = T.alloc_buffer([1, 18, 18, 4, 4, 32], dtype="float32", scope="shared") - weight_shared = T.alloc_buffer([3, 3, 4, 4, 32, 32], dtype="float32", scope="shared") - for i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused in T.thread_binding(256, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_i4_1_i5_1_fused in T.thread_binding(1, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused in T.thread_binding(4, thread="threadIdx.x"): - for i6_0, i7_0, i8_0, i9_0 in T.grid(3, 3, 2, 8): - for ax0_ax1_ax2_ax3_ax4_ax5_fused in T.serial(48): + T.block_attr({"meta_schedule.unroll_explicit": 64}) + conv2d_capsule_nhwijc_local = T.alloc_buffer((1, 8, 8, 4, 4, 32), scope="local") + PadInput_shared = T.alloc_buffer((1, 18, 18, 4, 4, 32), 
scope="shared") + weight_shared = T.alloc_buffer((3, 3, 4, 4, 32, 32), scope="shared") + for n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused in T.thread_binding(256, thread="blockIdx.x"): + for n_1_h_1_w_1_cap_i_1_cap_j_1_co_1_fused in T.thread_binding(1, thread="vthread.x"): + for n_2_h_2_w_2_cap_i_2_cap_j_2_co_2_fused in T.thread_binding(4, thread="threadIdx.x"): + for rh_0, rw_0, cap_k_0, rc_0 in T.grid(3, 3, 2, 8): + for ax0_ax1_ax2_ax3_ax4_ax5_fused in range(48): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(18, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 4 + i6_0 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 48 // 16) - v2 = T.axis.spatial(18, T.Add(i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0, 0)) - v3 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 16 // 8) - v4 = T.axis.spatial(4, i8_0 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 8 // 4) - v5 = T.axis.spatial(32, i9_0 * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 4) + v1 = T.axis.spatial(18, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused // 64 * 4 + rh_0 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 48 // 16) + v2 = T.axis.spatial(18, T.Add(n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 64 // 8 * 2 + rw_0, 0)) + v3 = T.axis.spatial(4, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 8 // 4 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 16 // 8) + v4 = T.axis.spatial(4, cap_k_0 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 8 // 4) + v5 = T.axis.spatial(32, rc_0 * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 4) T.reads(inputs[v0, v1 - 1, v2 - 1, v3, v4, v5]) T.writes(PadInput_shared[v0, v1, v2, v3, v4, v5]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) - PadInput_shared[v0, v1, v2, v3, v4, v5] = T.if_then_else(1 <= v1 and v1 < 17 and 1 <= v2 and v2 < 17, inputs[v0, v1 - 1, v2 - 1, v3, v4, v5], T.float32(0), dtype="float32") - for ax0_ax1_ax2_ax3_ax4_ax5_fused in T.serial(256): + T.block_attr({"meta_schedule.cooperative_fetch": 2}) + PadInput_shared[v0, v1, v2, v3, v4, v5] = T.if_then_else(1 <= v1 and v1 < 17 and 1 <= v2 and v2 < 17, inputs[v0, v1 - 1, v2 - 1, v3, v4, v5], T.float32(0)) + for ax0_ax1_ax2_ax3_ax4_ax5_fused in range(256): with T.block("weight_shared"): - v0, v1 = T.axis.remap("SS", [i6_0, i7_0]) - v2 = T.axis.spatial(4, i8_0 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused // 128) + v0, v1 = T.axis.remap("SS", [rh_0, rw_0]) + v2 = T.axis.spatial(4, cap_k_0 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused // 128) v3 = T.axis.spatial(4, ax0_ax1_ax2_ax3_ax4_ax5_fused % 128 // 32) - v4 = T.axis.spatial(32, i9_0 * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 32 // 8) - v5 = T.axis.spatial(32, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 4 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 8) + v4 = T.axis.spatial(32, rc_0 * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 32 // 8) + v5 = T.axis.spatial(32, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 4 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 8) T.reads(weight[v0, v1, v2, v3, v4, v5]) T.writes(weight_shared[v0, v1, v2, v3, v4, v5]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) weight_shared[v0, v1, v2, v3, v4, v5] = weight[v0, v1, v2, v3, v4, v5] - for i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3, i6_2, i7_2, i8_2, i9_2, i0_4, i1_4, i2_4, i3_4, i4_4, i5_4 in T.grid(1, 1, 1, 4, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 8): + for rh_1, rw_1, cap_k_1, rc_1, n_3, h_3, w_3, cap_i_3, cap_j_3, co_3, rh_2, rw_2, cap_k_2, rc_2, n_4, h_4, w_4, cap_i_4, cap_j_4, co_4 in T.grid(1, 1, 1, 4, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 8): with 
T.block("conv2d_capsule_nhwijc"): - n = T.axis.spatial(1, i0_4 + i0_3) - h = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 2 + i1_3 + i1_4) - w = T.axis.spatial(8, i2_3 + i2_4 + i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8) - cap_i = T.axis.spatial(4, i3_3 + i3_4 + i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused // 2) - cap_j = T.axis.spatial(4, i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused % 2 * 2 + i4_3 * 2 + i4_4) - co = T.axis.spatial(32, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 4 * 8 + i5_3 * 8 + i5_4) - rh = T.axis.reduce(3, i6_1 + i6_2 + i6_0) - rw = T.axis.reduce(3, i7_0 + i7_1 + i7_2) - cap_k = T.axis.reduce(4, i8_0 * 2 + i8_1 * 2 + i8_2) - rc = T.axis.reduce(32, i9_0 * 4 + i9_1 + i9_2) - T.reads(PadInput_shared[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight_shared[rh, rw, cap_k, cap_j, rc, co]) - T.writes(conv2d_capsule_nhwijc_local[n, h, w, cap_i, cap_j, co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_n = T.axis.spatial(1, n_3 + n_4) + v_h = T.axis.spatial(8, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused // 64 * 2 + h_3 + h_4) + v_w = T.axis.spatial(8, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 64 // 8 + w_3 + w_4) + v_cap_i = T.axis.spatial(4, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 8 // 4 * 2 + n_2_h_2_w_2_cap_i_2_cap_j_2_co_2_fused // 2 + cap_i_3 + cap_i_4) + v_cap_j = T.axis.spatial(4, n_2_h_2_w_2_cap_i_2_cap_j_2_co_2_fused % 2 * 2 + cap_j_3 * 2 + cap_j_4) + v_co = T.axis.spatial(32, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 4 * 8 + co_3 * 8 + co_4) + v_rh = T.axis.reduce(3, rh_0 + rh_1 + rh_2) + v_rw = T.axis.reduce(3, rw_0 + rw_1 + rw_2) + v_cap_k = T.axis.reduce(4, cap_k_0 * 2 + cap_k_1 * 2 + cap_k_2) + v_rc = T.axis.reduce(32, rc_0 * 4 + rc_1 + rc_2) + T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc], weight_shared[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co]) + T.writes(conv2d_capsule_nhwijc_local[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - conv2d_capsule_nhwijc_local[n, h, w, cap_i, cap_j, co] = T.float32(0) - conv2d_capsule_nhwijc_local[n, h, w, cap_i, cap_j, co] = conv2d_capsule_nhwijc_local[n, h, w, cap_i, cap_j, co] + PadInput_shared[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc] * weight_shared[rh, rw, cap_k, cap_j, rc, co] + conv2d_capsule_nhwijc_local[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = T.float32(0) + conv2d_capsule_nhwijc_local[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = conv2d_capsule_nhwijc_local[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc] * weight_shared[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 2, 1, 1, 2, 8): with T.block("conv2d_capsule_nhwijc_local"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 2 + ax1) - v2 = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 + ax2) - v3 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused // 2 + ax3) - v4 = T.axis.spatial(4, i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused % 2 * 2 + ax4) - v5 = T.axis.spatial(32, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 4 * 8 + ax5) + v1 = T.axis.spatial(8, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused // 64 
* 2 + ax1) + v2 = T.axis.spatial(8, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 64 // 8 + ax2) + v3 = T.axis.spatial(4, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 8 // 4 * 2 + n_2_h_2_w_2_cap_i_2_cap_j_2_co_2_fused // 2 + ax3) + v4 = T.axis.spatial(4, n_2_h_2_w_2_cap_i_2_cap_j_2_co_2_fused % 2 * 2 + ax4) + v5 = T.axis.spatial(32, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 4 * 8 + ax5) T.reads(conv2d_capsule_nhwijc_local[v0, v1, v2, v3, v4, v5]) T.writes(conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5]) conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5] = conv2d_capsule_nhwijc_local[v0, v1, v2, v3, v4, v5] @@ -393,21 +389,19 @@ def test_cuda_dep(): # fmt: off @T.prim_func def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":16}) - depth_conv2d_nhwc_local = T.alloc_buffer([1, 112, 112, 32], dtype="float32", scope="local") - PadInput_shared = T.alloc_buffer([1, 114, 114, 32], dtype="float32", scope="shared") - placeholder_shared = T.alloc_buffer([1, 3, 3, 32], dtype="float32", scope="shared") - for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(1, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(8, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(14, thread="threadIdx.x"): - for i4_0, i5_0 in T.grid(1, 1): - for ax0_ax1_ax2_ax3_fused in T.serial(415872): + T.block_attr({"meta_schedule.unroll_explicit": 16}) + depth_conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 32), scope="local") + PadInput_shared = T.alloc_buffer((1, 114, 114, 32), scope="shared") + placeholder_shared = T.alloc_buffer((1, 3, 3, 32), scope="shared") + for n_0_h_0_w_0_c_0_fused in T.thread_binding(1, thread="blockIdx.x"): + for n_1_h_1_w_1_c_1_fused in T.thread_binding(8, thread="vthread.x"): + for n_2_h_2_w_2_c_2_fused in T.thread_binding(14, thread="threadIdx.x"): + for rh_0, rw_0 in T.grid(1, 1): + for ax0_ax1_ax2_ax3_fused in range(415872): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(114, ax0_ax1_ax2_ax3_fused // 3648) @@ -415,9 +409,9 @@ def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. v3 = T.axis.spatial(32, ax0_ax1_ax2_ax3_fused % 32) T.reads(placeholder[v0, v1 - 1, v2 - 1, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) - PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 113 and 1 <= v2 and v2 < 113, placeholder[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32") - for ax0_ax1_ax2_ax3_fused in T.serial(288): + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 113 and 1 <= v2 and v2 < 113, placeholder[v0, v1 - 1, v2 - 1, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused in range(288): with T.block("placeholder_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 96) @@ -425,28 +419,28 @@ def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. 
v3 = T.axis.spatial(32, ax0_ax1_ax2_ax3_fused % 32) T.reads(placeholder_1[v0, v1, v2, v3]) T.writes(placeholder_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) placeholder_shared[v0, v1, v2, v3] = placeholder_1[v0, v1, v2, v3] - for i4_1, i5_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i0_4, i1_4, i2_4, i3_4 in T.grid(3, 1, 1, 4, 16, 8, 1, 3, 1, 7, 1, 1): + for rh_1, rw_1, n_3, h_3, w_3, c_3, rh_2, rw_2, n_4, h_4, w_4, c_4 in T.grid(3, 1, 1, 4, 16, 8, 1, 3, 1, 7, 1, 1): with T.block("depth_conv2d_nhwc"): - n = T.axis.spatial(1, i0_4 + i0_3) - h = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_fused // 2 * 28 + i1_3 * 7 + i1_4) - w = T.axis.spatial(112, i2_4 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 16 + i2_3) - c = T.axis.spatial(32, i0_1_i1_1_i2_1_i3_1_fused % 2 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 8 + i3_3 + i3_4) - rh = T.axis.reduce(3, i4_2 + i4_0 * 3 + i4_1) - rw = T.axis.reduce(3, i5_0 * 3 + i5_1 * 3 + i5_2) - T.reads(PadInput_shared[n, h + rh, w + rw, c], placeholder_shared[0, rh, rw, c]) - T.writes(depth_conv2d_nhwc_local[n, h, w, c]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_n = T.axis.spatial(1, n_3 + n_4) + v_h = T.axis.spatial(112, n_1_h_1_w_1_c_1_fused // 2 * 28 + h_3 * 7 + h_4) + v_w = T.axis.spatial(112, n_2_h_2_w_2_c_2_fused // 2 * 16 + w_3 + w_4) + v_c = T.axis.spatial(32, n_1_h_1_w_1_c_1_fused % 2 * 16 + n_2_h_2_w_2_c_2_fused % 2 * 8 + c_3 + c_4) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1 + rh_2) + v_rw = T.axis.reduce(3, rw_0 * 3 + rw_1 * 3 + rw_2) + T.reads(PadInput_shared[v_n, v_h + v_rh, v_w + v_rw, v_c], placeholder_shared[0, v_rh, v_rw, v_c]) + T.writes(depth_conv2d_nhwc_local[v_n, v_h, v_w, v_c]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - depth_conv2d_nhwc_local[n, h, w, c] = T.float32(0) - depth_conv2d_nhwc_local[n, h, w, c] = depth_conv2d_nhwc_local[n, h, w, c] + PadInput_shared[n, h + rh, w + rw, c] * placeholder_shared[0, rh, rw, c] + depth_conv2d_nhwc_local[v_n, v_h, v_w, v_c] = T.float32(0) + depth_conv2d_nhwc_local[v_n, v_h, v_w, v_c] = depth_conv2d_nhwc_local[v_n, v_h, v_w, v_c] + PadInput_shared[v_n, v_h + v_rh, v_w + v_rw, v_c] * placeholder_shared[0, v_rh, v_rw, v_c] for ax0, ax1, ax2, ax3 in T.grid(1, 28, 16, 8): with T.block("depth_conv2d_nhwc_local"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_fused // 2 * 28 + ax1) - v2 = T.axis.spatial(112, i0_2_i1_2_i2_2_i3_2_fused // 2 * 16 + ax2) - v3 = T.axis.spatial(32, i0_1_i1_1_i2_1_i3_1_fused % 2 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 8 + ax3) + v1 = T.axis.spatial(112, n_1_h_1_w_1_c_1_fused // 2 * 28 + ax1) + v2 = T.axis.spatial(112, n_2_h_2_w_2_c_2_fused // 2 * 16 + ax2) + v3 = T.axis.spatial(32, n_1_h_1_w_1_c_1_fused % 2 * 16 + n_2_h_2_w_2_c_2_fused % 2 * 8 + ax3) T.reads(depth_conv2d_nhwc_local[v0, v1, v2, v3]) T.writes(depth_conv2d_nhwc[v0, v1, v2, v3]) depth_conv2d_nhwc[v0, v1, v2, v3] = depth_conv2d_nhwc_local[v0, v1, v2, v3] @@ -476,59 +470,57 @@ def test_cuda_dil(): # fmt: off @T.prim_func def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + 
T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":512}) - conv2d_nhwc_local = T.alloc_buffer([1, 109, 109, 64], dtype="float32", scope="local") - PadInput_shared = T.alloc_buffer([1, 230, 230, 3], dtype="float32", scope="shared") - weight_shared = T.alloc_buffer([7, 7, 3, 64], dtype="float32", scope="shared") - for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(218, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(109, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(1, thread="threadIdx.x"): - for i4_0, i5_0, i6_0 in T.grid(7, 7, 3): - for ax0_ax1_ax2_ax3_fused in T.serial(217): + T.block_attr({"meta_schedule.unroll_explicit": 512}) + conv2d_nhwc_local = T.alloc_buffer((1, 109, 109, 64), scope="local") + PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared") + weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared") + for n_0_h_0_w_0_co_0_fused in T.thread_binding(218, thread="blockIdx.x"): + for n_1_h_1_w_1_co_1_fused in T.thread_binding(109, thread="vthread.x"): + for n_2_h_2_w_2_co_2_fused in T.thread_binding(1, thread="threadIdx.x"): + for rh_0, rw_0, rc_0 in T.grid(7, 7, 3): + for ax0_ax1_ax2_ax3_fused in range(217): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(230, T.Add(i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2, 0)) - v2 = T.axis.spatial(230, i5_0 * 2 + ax0_ax1_ax2_ax3_fused % 217) - v3 = T.axis.spatial(3, T.Add(i6_0, 0)) + v1 = T.axis.spatial(230, T.Add(n_0_h_0_w_0_co_0_fused // 2 * 2 + rh_0 * 2, 0)) + v2 = T.axis.spatial(230, rw_0 * 2 + ax0_ax1_ax2_ax3_fused % 217) + v3 = T.axis.spatial(3, T.Add(rc_0, 0)) T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) - PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0), dtype="float32") - for ax0_ax1_ax2_ax3_fused in T.serial(32): + T.block_attr({"meta_schedule.cooperative_fetch": 2}) + PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused in range(32): with T.block("weight_shared"): - v0, v1, v2 = T.axis.remap("SSS", [i4_0, i5_0, i6_0]) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused) + v0, v1, v2 = T.axis.remap("SSS", [rh_0, rw_0, rc_0]) + v3 = T.axis.spatial(64, n_0_h_0_w_0_co_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] - for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 4): + for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 4): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_4) - h = T.axis.spatial(109, i1_4 + i0_0_i1_0_i2_0_i3_0_fused // 2 + i1_3) - w = T.axis.spatial(109, i0_1_i1_1_i2_1_i3_1_fused + i2_3 + i2_4) - co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_3 * 4 + i3_4) - rh = T.axis.reduce(7, i4_0 + i4_1 + i4_2) - rw = T.axis.reduce(7, i5_2 + i5_0 + i5_1) - rc = 
T.axis.reduce(3, i6_1 + i6_2 + i6_0) - T.reads(PadInput_shared[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight_shared[rh, rw, rc, co]) - T.writes(conv2d_nhwc_local[n, h, w, co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_n = T.axis.spatial(1, n_3 + n_4) + v_h = T.axis.spatial(109, n_0_h_0_w_0_co_0_fused // 2 + h_3 + h_4) + v_w = T.axis.spatial(109, n_1_h_1_w_1_co_1_fused + w_3 + w_4) + v_co = T.axis.spatial(64, n_0_h_0_w_0_co_0_fused % 2 * 32 + co_3 * 4 + co_4) + v_rh = T.axis.reduce(7, rh_0 + rh_1 + rh_2) + v_rw = T.axis.reduce(7, rw_0 + rw_1 + rw_2) + v_rc = T.axis.reduce(3, rc_0 + rc_1 + rc_2) + T.reads(PadInput_shared[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - conv2d_nhwc_local[n, h, w, co] = T.float32(0) - conv2d_nhwc_local[n, h, w, co] = conv2d_nhwc_local[n, h, w, co] + PadInput_shared[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc] * weight_shared[rh, rw, rc, co] + conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 32): with T.block("conv2d_nhwc_local"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(109, i0_0_i1_0_i2_0_i3_0_fused // 2 + ax1) - v2 = T.axis.spatial(109, i0_1_i1_1_i2_1_i3_1_fused + ax2) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax3) + v1 = T.axis.spatial(109, n_0_h_0_w_0_co_0_fused // 2 + ax1) + v2 = T.axis.spatial(109, n_1_h_1_w_1_co_1_fused + ax2) + v3 = T.axis.spatial(64, n_0_h_0_w_0_co_0_fused % 2 * 32 + ax3) T.reads(conv2d_nhwc_local[v0, v1, v2, v3]) T.writes(conv2d_nhwc[v0, v1, v2, v3]) conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3] @@ -559,55 +551,53 @@ def test_cuda_gmm(): # fmt: off @T.prim_func def gmm_0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":1024}) - Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") - X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") - Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") - for i0_0_i1_0_i2_0_fused in T.thread_binding(1, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_fused in T.thread_binding(32, thread="vthread.x"): - for i0_2_i1_2_i2_2_fused in T.thread_binding(2, thread="threadIdx.x"): - for i3_0 in T.serial(1): - for ax0_ax1_ax2_fused in T.serial(16384): + T.block_attr({"meta_schedule.unroll_explicit": 1024}) + Z_local = T.alloc_buffer((1, 128, 128), scope="local") + X_shared = T.alloc_buffer((1, 128, 128), scope="shared") + Y_shared = T.alloc_buffer((1, 128, 128), scope="shared") + for b_0_i_0_j_0_fused in T.thread_binding(1, thread="blockIdx.x"): + for b_1_i_1_j_1_fused in T.thread_binding(32, 
thread="vthread.x"): + for b_2_i_2_j_2_fused in T.thread_binding(2, thread="threadIdx.x"): + for k_0 in range(1): + for ax0_ax1_ax2_fused in range(16384): with T.block("X_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128) v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128) T.reads(X[v0, v1, v2]) T.writes(X_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) + T.block_attr({"meta_schedule.cooperative_fetch": 2}) X_shared[v0, v1, v2] = X[v0, v1, v2] - for ax0_ax1_ax2_fused in T.serial(16384): + for ax0_ax1_ax2_fused in range(16384): with T.block("Y_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128) v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128) T.reads(Y[v0, v1, v2]) T.writes(Y_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch":1}) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) Y_shared[v0, v1, v2] = Y[v0, v1, v2] - for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(32, 1, 2, 64, 4, 1, 2, 1): + for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(32, 1, 2, 64, 4, 1, 2, 1): with T.block("Z"): - b = T.axis.spatial(1, i0_4 + i0_3) - i = T.axis.spatial(128, i0_1_i1_1_i2_1_fused * 4 + i1_3 * 2 + i1_4) - j = T.axis.spatial(128, i2_4 + i0_2_i1_2_i2_2_fused * 64 + i2_3) - k = T.axis.reduce(128, i3_0 * 128 + i3_1 * 4 + i3_2) - T.reads(X_shared[b, i, k], Y_shared[b, k, j]) - T.writes(Z_local[b, i, j]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_b = T.axis.spatial(1, b_3 + b_4) + v_i = T.axis.spatial(128, b_1_i_1_j_1_fused * 4 + i_3 * 2 + i_4) + v_j = T.axis.spatial(128, b_2_i_2_j_2_fused * 64 + j_3 + j_4) + v_k = T.axis.reduce(128, k_0 * 128 + k_1 * 4 + k_2) + T.reads(X_shared[v_b, v_i, v_k], Y_shared[v_b, v_k, v_j]) + T.writes(Z_local[v_b, v_i, v_j]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - Z_local[b, i, j] = T.float32(0) - Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + Z_local[v_b, v_i, v_j] = T.float32(0) + Z_local[v_b, v_i, v_j] = Z_local[v_b, v_i, v_j] + X_shared[v_b, v_i, v_k] * Y_shared[v_b, v_k, v_j] for ax0, ax1, ax2 in T.grid(1, 4, 64): with T.block("Z_local"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(128, i0_1_i1_1_i2_1_fused * 4 + ax1) - v2 = T.axis.spatial(128, i0_2_i1_2_i2_2_fused * 64 + ax2) + v1 = T.axis.spatial(128, b_1_i_1_j_1_fused * 4 + ax1) + v2 = T.axis.spatial(128, b_2_i_2_j_2_fused * 64 + ax2) T.reads(Z_local[v0, v1, v2]) T.writes(Z[v0, v1, v2]) Z[v0, v1, v2] = Z_local[v0, v1, v2] @@ -635,60 +625,58 @@ def test_cuda_grp(): # fmt: off @T.prim_func def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":16}) - conv2d_nhwc_local = T.alloc_buffer([1, 28, 28, 128], dtype="float32", scope="local") - PadInput_shared = T.alloc_buffer([1, 58, 58, 64], dtype="float32", scope="shared") - weight_shared = T.alloc_buffer([3, 3, 16, 128], dtype="float32", scope="shared") - for i0_0_i1_0_i2_0_i3_0_fused in 
T.thread_binding(2, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(1, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(112, thread="threadIdx.x"): - for i4_0, i5_0, i6_0 in T.grid(3, 3, 1): - for ax0_ax1_ax2_ax3_fused in T.serial(95040): + T.block_attr({"meta_schedule.unroll_explicit": 16}) + conv2d_nhwc_local = T.alloc_buffer((1, 28, 28, 128), scope="local") + PadInput_shared = T.alloc_buffer((1, 58, 58, 64), scope="shared") + weight_shared = T.alloc_buffer((3, 3, 16, 128), scope="shared") + for n_0_h_0_w_0_co_0_fused in T.thread_binding(2, thread="blockIdx.x"): + for n_1_h_1_w_1_co_1_fused in T.thread_binding(1, thread="vthread.x"): + for n_2_h_2_w_2_co_2_fused in T.thread_binding(112, thread="threadIdx.x"): + for rh_0, rw_0, rc_0 in T.grid(3, 3, 1): + for ax0_ax1_ax2_ax3_fused in range(95040): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused * 28 + i4_0 + ax0_ax1_ax2_ax3_fused % 95040 // 3520) - v2 = T.axis.spatial(58, i5_0 + ax0_ax1_ax2_ax3_fused % 3520 // 64) + v1 = T.axis.spatial(58, n_0_h_0_w_0_co_0_fused * 28 + rh_0 + ax0_ax1_ax2_ax3_fused % 95040 // 3520) + v2 = T.axis.spatial(58, rw_0 + ax0_ax1_ax2_ax3_fused % 3520 // 64) v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) T.reads(inputs[v0, v1 - 1, v2 - 1, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) - PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 57 and 1 <= v2 and v2 < 57, inputs[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32") - for ax0_ax1_ax2_ax3_fused in T.serial(2048): + T.block_attr({"meta_schedule.cooperative_fetch": 2}) + PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 57 and 1 <= v2 and v2 < 57, inputs[v0, v1 - 1, v2 - 1, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused in range(2048): with T.block("weight_shared"): - v0, v1 = T.axis.remap("SS", [i4_0, i5_0]) + v0, v1 = T.axis.remap("SS", [rh_0, rw_0]) v2 = T.axis.spatial(16, ax0_ax1_ax2_ax3_fused // 128) v3 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused % 128) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":1}) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] - for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 2, 1, 2, 1, 2, 1, 1, 8, 1, 7, 4, 4): + for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(1, 1, 2, 1, 2, 1, 2, 1, 1, 8, 1, 7, 4, 4): with T.block("conv2d_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_4) - h = T.axis.spatial(28, i0_0_i1_0_i2_0_i3_0_fused * 14 + i1_3 * 7 + i1_4) - w = T.axis.spatial(28, i0_2_i1_2_i2_2_i3_2_fused // 16 * 4 + i2_3 * 4 + i2_4) - co = T.axis.spatial(128, i0_2_i1_2_i2_2_i3_2_fused % 16 * 8 + i3_3 * 4 + i3_4) - rh = T.axis.reduce(3, i4_0 + i4_1 + i4_2) - rw = T.axis.reduce(3, i5_2 + i5_0 + i5_1) - rc = T.axis.reduce(16, i6_0 * 16 + i6_1 * 8 + i6_2) - T.reads(PadInput_shared[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc], weight_shared[rh, rw, rc, co]) - T.writes(conv2d_nhwc_local[n, h, w, co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_n = T.axis.spatial(1, n_3 + n_4) + v_h = T.axis.spatial(28, n_0_h_0_w_0_co_0_fused * 14 + h_3 * 7 + h_4) + v_w = T.axis.spatial(28, n_2_h_2_w_2_co_2_fused // 
16 * 4 + w_3 * 4 + w_4) + v_co = T.axis.spatial(128, n_2_h_2_w_2_co_2_fused % 16 * 8 + co_3 * 4 + co_4) + v_rh = T.axis.reduce(3, rh_0 + rh_1 + rh_2) + v_rw = T.axis.reduce(3, rw_0 + rw_1 + rw_2) + v_rc = T.axis.reduce(16, rc_0 * 16 + rc_1 * 8 + rc_2) + T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - conv2d_nhwc_local[n, h, w, co] = T.float32(0) - conv2d_nhwc_local[n, h, w, co] = conv2d_nhwc_local[n, h, w, co] + PadInput_shared[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc] * weight_shared[rh, rw, rc, co] + conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 14, 4, 8): with T.block("conv2d_nhwc_local"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(28, i0_0_i1_0_i2_0_i3_0_fused * 14 + ax1) - v2 = T.axis.spatial(28, i0_2_i1_2_i2_2_i3_2_fused // 16 * 4 + ax2) - v3 = T.axis.spatial(128, i0_2_i1_2_i2_2_i3_2_fused % 16 * 8 + ax3) + v1 = T.axis.spatial(28, n_0_h_0_w_0_co_0_fused * 14 + ax1) + v2 = T.axis.spatial(28, n_2_h_2_w_2_co_2_fused // 16 * 4 + ax2) + v3 = T.axis.spatial(128, n_2_h_2_w_2_co_2_fused % 16 * 8 + ax3) T.reads(conv2d_nhwc_local[v0, v1, v2, v3]) T.writes(conv2d_nhwc[v0, v1, v2, v3]) conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3] @@ -719,61 +707,59 @@ def test_cuda_t2d(): # fmt: off @T.prim_func def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":64}) - conv2d_transpose_nhwc_local = T.alloc_buffer([1, 8, 8, 256], dtype="float32", scope="local") - PadInput_shared = T.alloc_buffer([1, 6, 6, 512], dtype="float32", scope="shared") - weight_shared = T.alloc_buffer([4, 4, 512, 256], dtype="float32", scope="shared") - for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(256, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(2, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(1, thread="threadIdx.x"): - for i4_0, i5_0, i6_0 in T.grid(4, 1, 16): - for ax0_ax1_ax2_ax3_fused in T.serial((i4_0 % 2 + 1) // 2 * 96 + 96): + T.block_attr({"meta_schedule.unroll_explicit": 64}) + conv2d_transpose_nhwc_local = T.alloc_buffer((1, 8, 8, 256), scope="local") + PadInput_shared = T.alloc_buffer((1, 6, 6, 512), scope="shared") + weight_shared = T.alloc_buffer((4, 4, 512, 256), scope="shared") + for n_0_h_0_w_0_co_0_fused in T.thread_binding(256, thread="blockIdx.x"): + for n_1_h_1_w_1_co_1_fused in T.thread_binding(2, thread="vthread.x"): + for n_2_h_2_w_2_co_2_fused in T.thread_binding(1, thread="threadIdx.x"): + for rh_0, rw_0, rc_0 in T.grid(4, 1, 16): + for ax0_ax1_ax2_ax3_fused in range(rh_0 % 2 * 96 + 96): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 64 + i4_0 // 
2 + ax0_ax1_ax2_ax3_fused % (96 * (i4_0 % 2 + 1)) // 96) - v2 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32) - v3 = T.axis.spatial(512, i6_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) + v1 = T.axis.spatial(6, n_0_h_0_w_0_co_0_fused // 64 + rh_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * (rh_0 % 2 + 1)) // 96) + v2 = T.axis.spatial(6, n_0_h_0_w_0_co_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32) + v3 = T.axis.spatial(512, rc_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(inputs[v0, v1 - 1, v2 - 1, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) - PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 5 and 1 <= v2 and v2 < 5, inputs[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32") - for ax0_ax1_ax2_ax3_fused in T.serial(2048): + T.block_attr({"meta_schedule.cooperative_fetch": 2}) + PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 5 and 1 <= v2 and v2 < 5, inputs[v0, v1 - 1, v2 - 1, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused in range(2048): with T.block("weight_shared"): - v0 = T.axis.spatial(4, i4_0 * -1 + 3) + v0 = T.axis.spatial(4, rh_0 * -1 + 3) v1 = T.axis.spatial(4, ax0_ax1_ax2_ax3_fused // 512) - v2 = T.axis.spatial(512, i6_0 * 32 + ax0_ax1_ax2_ax3_fused % 512 // 16) - v3 = T.axis.spatial(256, i0_0_i1_0_i2_0_i3_0_fused % 16 * 16 + ax0_ax1_ax2_ax3_fused % 16) + v2 = T.axis.spatial(512, rc_0 * 32 + ax0_ax1_ax2_ax3_fused % 512 // 16) + v3 = T.axis.spatial(256, n_0_h_0_w_0_co_0_fused % 16 * 16 + ax0_ax1_ax2_ax3_fused % 16) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] - for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 4, 1, 2, 1, 8, 1, 4, 8, 1, 1, 2, 1): + for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(1, 1, 4, 1, 2, 1, 8, 1, 4, 8, 1, 1, 2, 1): with T.block("conv2d_transpose_nhwc"): - n = T.axis.spatial(1, i0_3 + i0_4) - h = T.axis.spatial(8, i1_4 + i0_0_i1_0_i2_0_i3_0_fused // 64 * 2 + i1_3) - w = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 * 2 + i2_3 * 2 + i2_4) - co = T.axis.spatial(256, i3_4 + i0_0_i1_0_i2_0_i3_0_fused % 16 * 16 + i0_1_i1_1_i2_1_i3_1_fused * 8 + i3_3) - rh = T.axis.reduce(4, i4_0 + i4_1 + i4_2) - rw = T.axis.reduce(4, i5_0 * 4 + i5_1 * 4 + i5_2) - rc = T.axis.reduce(512, i6_0 * 32 + i6_1 * 8 + i6_2) - T.reads(PadInput_shared[n, (h + rh) // 2, (w + rw) // 2, rc], weight_shared[3 - rh, 3 - rw, rc, co]) - T.writes(conv2d_transpose_nhwc_local[n, h, w, co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_n = T.axis.spatial(1, n_3 + n_4) + v_h = T.axis.spatial(8, n_0_h_0_w_0_co_0_fused // 64 * 2 + h_3 + h_4) + v_w = T.axis.spatial(8, n_0_h_0_w_0_co_0_fused % 64 // 16 * 2 + w_3 * 2 + w_4) + v_co = T.axis.spatial(256, n_0_h_0_w_0_co_0_fused % 16 * 16 + n_1_h_1_w_1_co_1_fused * 8 + co_3 + co_4) + v_rh = T.axis.reduce(4, rh_0 + rh_1 + rh_2) + v_rw = T.axis.reduce(4, rw_0 * 4 + rw_1 * 4 + rw_2) + v_rc = T.axis.reduce(512, rc_0 * 32 + rc_1 * 8 + rc_2) + T.reads(PadInput_shared[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], weight_shared[3 - v_rh, 3 - v_rw, v_rc, v_co]) + T.writes(conv2d_transpose_nhwc_local[v_n, v_h, v_w, v_co]) + 
T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - conv2d_transpose_nhwc_local[n, h, w, co] = T.float32(0) - conv2d_transpose_nhwc_local[n, h, w, co] = conv2d_transpose_nhwc_local[n, h, w, co] + T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, PadInput_shared[n, (h + rh) // 2, (w + rw) // 2, rc], T.float32(0), dtype="float32") * weight_shared[3 - rh, 3 - rw, rc, co] + conv2d_transpose_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_transpose_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_transpose_nhwc_local[v_n, v_h, v_w, v_co] + T.if_then_else((v_h + v_rh) % 2 == 0 and (v_w + v_rw) % 2 == 0, PadInput_shared[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], T.float32(0)) * weight_shared[3 - v_rh, 3 - v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 2, 2, 8): with T.block("conv2d_transpose_nhwc_local"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_fused // 64 * 2 + ax1) - v2 = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 * 2 + ax2) - v3 = T.axis.spatial(256, i0_0_i1_0_i2_0_i3_0_fused % 16 * 16 + i0_1_i1_1_i2_1_i3_1_fused * 8 + ax3) + v1 = T.axis.spatial(8, n_0_h_0_w_0_co_0_fused // 64 * 2 + ax1) + v2 = T.axis.spatial(8, n_0_h_0_w_0_co_0_fused % 64 // 16 * 2 + ax2) + v3 = T.axis.spatial(256, n_0_h_0_w_0_co_0_fused % 16 * 16 + n_1_h_1_w_1_co_1_fused * 8 + ax3) T.reads(conv2d_transpose_nhwc_local[v0, v1, v2, v3]) T.writes(conv2d_transpose_nhwc[v0, v1, v2, v3]) conv2d_transpose_nhwc[v0, v1, v2, v3] = conv2d_transpose_nhwc_local[v0, v1, v2, v3] @@ -805,61 +791,57 @@ def test_cuda_nrm(): # fmt: off @T.prim_func def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":512}) - C = T.alloc_buffer([1], dtype="float32") - for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): - for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"): - for i1, i2 in T.grid(256, 256): + T.block_attr({"meta_schedule.unroll_explicit": 512}) + C = T.alloc_buffer((1,)) + for b_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for b_fused_1 in T.thread_binding(1, thread="threadIdx.x"): + for i, j in T.grid(256, 256): with T.block("C"): - b = T.axis.spatial(1, 0) - i, j = T.axis.remap("RR", [i1, i2]) - T.reads(A[b, i, j]) - T.writes(C[b]) + v_b = T.axis.spatial(1, 0) + v_i, v_j = T.axis.remap("RR", [i, j]) + T.reads(A[v_b, v_i, v_j]) + T.writes(C[v_b]) with T.init(): - C[b] = T.float32(0) - C[b] = C[b] + A[b, i, j] * A[b, i, j] - for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): - for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"): + C[v_b] = T.float32(0) + C[v_b] = C[v_b] + A[v_b, v_i, v_j] * A[v_b, v_i, v_j] + for b_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for b_fused_1 in T.thread_binding(1, thread="threadIdx.x"): with T.block("D"): - b = T.axis.spatial(1, 0) - T.reads(C[b]) - T.writes(D[b]) - D[b] = T.sqrt(C[b], dtype="float32") + v_b = T.axis.spatial(1, 0) + T.reads(C[v_b]) + T.writes(D[v_b]) + D[v_b] = T.sqrt(C[v_b]) @T.prim_func def nrm_1(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + 
T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":1024}) - C_shared = T.alloc_buffer([1], dtype="float32", scope="shared") - for i0_0_fused in T.thread_binding(1, thread="blockIdx.x"): + T.block_attr({"meta_schedule.unroll_explicit": 1024}) + C_shared = T.alloc_buffer((1,), scope="shared") + for b_0_fused in T.thread_binding(1, thread="blockIdx.x"): for ax0, ax1_ax2_fused_0 in T.grid(1, 512): for ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): with T.block("C"): - b = T.axis.spatial(1, ax0) - i = T.axis.reduce(256, (ax1_ax2_fused_0 * 128 + ax1_ax2_fused_1) // 256) - j = T.axis.reduce(256, (ax1_ax2_fused_0 * 128 + ax1_ax2_fused_1) % 256) - T.reads(A[b, i, j]) - T.writes(C_shared[b]) + v_b = T.axis.spatial(1, ax0) + v_i = T.axis.reduce(256, (ax1_ax2_fused_0 * 128 + ax1_ax2_fused_1) // 256) + v_j = T.axis.reduce(256, (ax1_ax2_fused_0 * 128 + ax1_ax2_fused_1) % 256) + T.reads(A[v_b, v_i, v_j]) + T.writes(C_shared[v_b]) with T.init(): - C_shared[b] = T.float32(0) - C_shared[b] = C_shared[b] + A[b, i, j] * A[b, i, j] - for i0_1 in T.thread_binding(128, thread="threadIdx.x"): + C_shared[v_b] = T.float32(0) + C_shared[v_b] = C_shared[v_b] + A[v_b, v_i, v_j] * A[v_b, v_i, v_j] + for b_1 in T.thread_binding(128, thread="threadIdx.x"): with T.block("D"): - b = T.axis.spatial(1, i0_1) - T.where(T.Mul(0, 128) + i0_1 < 1) - T.reads(C_shared[b]) - T.writes(D[b]) - D[b] = T.sqrt(C_shared[b], dtype="float32") + v_b = T.axis.spatial(1, b_1) + T.where(T.Mul(0, 128) + b_1 < 1) + T.reads(C_shared[v_b]) + T.writes(D[v_b]) + D[v_b] = T.sqrt(C_shared[v_b]) # fmt: on decision_0 = [ ("SampleCategorical", 3), @@ -882,176 +864,168 @@ def test_cuda_sfm(): # fmt: off @T.prim_func def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":0}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + T.block_attr({"meta_schedule.unroll_explicit": 0}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_expsum = T.alloc_buffer((256,)) for i0_fused_0 in T.thread_binding(2, thread="blockIdx.x"): for i0_fused_1 in T.thread_binding(128, thread="threadIdx.x"): - for i1 in T.serial(256): + for k in range(256): with T.block("T_softmax_maxelem"): - i0 = T.axis.spatial(256, i0_fused_0 * 128 + i0_fused_1) - k = T.axis.reduce(256, i1) - T.reads(A[i0, k]) - T.writes(T_softmax_maxelem[i0]) + v_i0 = T.axis.spatial(256, i0_fused_0 * 128 + i0_fused_1) + v_k = T.axis.reduce(256, k) + T.reads(A[v_i0, v_k]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0] = T.max(T_softmax_maxelem[i0], A[i0, k]) + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k]) for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): - for i1 in T.serial(256): + for k in range(256): with T.block("T_softmax_expsum"): - i0 = T.axis.spatial(256, i0_fused_0 * 256 + i0_fused_1) - k = T.axis.reduce(256, i1) - T.reads(A[i0, k], T_softmax_maxelem[i0]) - 
T.writes(T_softmax_expsum[i0]) + v_i0 = T.axis.spatial(256, i0_fused_0 * 256 + i0_fused_1) + v_k = T.axis.reduce(256, k) + T.reads(A[v_i0, v_k], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_expsum[v_i0]) with T.init(): - T_softmax_expsum[i0] = T.float32(0) - T_softmax_expsum[i0] = T_softmax_expsum[i0] + T.exp(A[i0, k] - T_softmax_maxelem[i0], dtype="float32") + T_softmax_expsum[v_i0] = T.float32(0) + T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T.exp(A[v_i0, v_k] - T_softmax_maxelem[v_i0]) for i0_i1_fused_0 in T.thread_binding(1024, thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("T_softmax_norm"): - i0 = T.axis.spatial(256, (i0_i1_fused_0 * 64 + i0_i1_fused_1) // 256) - i1 = T.axis.spatial(256, (i0_i1_fused_0 * 64 + i0_i1_fused_1) % 256) - T.reads(A[i0, i1], T_softmax_maxelem[i0], T_softmax_expsum[i0]) - T.writes(T_softmax_norm[i0, i1]) - T.block_attr({"axis":1}) - T_softmax_norm[i0, i1] = T.exp(A[i0, i1] - T_softmax_maxelem[i0], dtype="float32") / T_softmax_expsum[i0] + v_i0 = T.axis.spatial(256, (i0_i1_fused_0 * 64 + i0_i1_fused_1) // 256) + v_i1 = T.axis.spatial(256, (i0_i1_fused_0 * 64 + i0_i1_fused_1) % 256) + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":16}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + T.block_attr({"meta_schedule.unroll_explicit": 16}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_expsum = T.alloc_buffer((256,)) for i0_fused in T.thread_binding(256, thread="blockIdx.x"): - for i1_0 in T.serial(64): - for i1_1 in T.thread_binding(4, thread="threadIdx.x"): + for k_0 in range(64): + for k_1 in T.thread_binding(4, thread="threadIdx.x"): with T.block("T_softmax_maxelem"): - i0 = T.axis.spatial(256, i0_fused) - k = T.axis.reduce(256, i1_0 * 4 + i1_1) - T.reads(A[i0, k]) - T.writes(T_softmax_maxelem[i0]) + v_i0 = T.axis.spatial(256, i0_fused) + v_k = T.axis.reduce(256, k_0 * 4 + k_1) + T.reads(A[v_i0, v_k]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0] = T.max(T_softmax_maxelem[i0], A[i0, k]) + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k]) for i0_fused_0 in T.thread_binding(4, thread="blockIdx.x"): for i0_fused_1 in T.thread_binding(64, thread="threadIdx.x"): - for i1 in T.serial(256): + for k in range(256): with T.block("T_softmax_expsum"): - i0 = T.axis.spatial(256, i0_fused_0 * 64 + i0_fused_1) - k = T.axis.reduce(256, i1) - T.reads(A[i0, k], T_softmax_maxelem[i0]) - T.writes(T_softmax_expsum[i0]) + v_i0 = T.axis.spatial(256, i0_fused_0 * 64 + i0_fused_1) + v_k = T.axis.reduce(256, k) + T.reads(A[v_i0, v_k], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_expsum[v_i0]) with T.init(): - T_softmax_expsum[i0] = T.float32(0) - T_softmax_expsum[i0] = T_softmax_expsum[i0] + T.exp(A[i0, k] - 
T_softmax_maxelem[i0], dtype="float32") + T_softmax_expsum[v_i0] = T.float32(0) + T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T.exp(A[v_i0, v_k] - T_softmax_maxelem[v_i0]) for i0_i1_fused_0 in T.thread_binding(256, thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_softmax_norm"): - i0 = T.axis.spatial(256, (i0_i1_fused_0 * 256 + i0_i1_fused_1) // 256) - i1 = T.axis.spatial(256, (i0_i1_fused_0 * 256 + i0_i1_fused_1) % 256) - T.reads(A[i0, i1], T_softmax_maxelem[i0], T_softmax_expsum[i0]) - T.writes(T_softmax_norm[i0, i1]) - T.block_attr({"axis":1}) - T_softmax_norm[i0, i1] = T.exp(A[i0, i1] - T_softmax_maxelem[i0], dtype="float32") / T_softmax_expsum[i0] + v_i0 = T.axis.spatial(256, (i0_i1_fused_0 * 256 + i0_i1_fused_1) // 256) + v_i1 = T.axis.spatial(256, (i0_i1_fused_0 * 256 + i0_i1_fused_1) % 256) + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":512}) - T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") - T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T.block_attr({"meta_schedule.unroll_explicit": 512}) + T_softmax_maxelem = T.alloc_buffer((256,)) + T_softmax_expsum_shared = T.alloc_buffer((256,), scope="shared") for i0_fused_0 in T.thread_binding(8, thread="blockIdx.x"): for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): - for i1 in T.serial(256): + for k in range(256): with T.block("T_softmax_maxelem"): - i0 = T.axis.spatial(256, i0_fused_0 * 32 + i0_fused_1) - k = T.axis.reduce(256, i1) - T.reads(A[i0, k]) - T.writes(T_softmax_maxelem[i0]) + v_i0 = T.axis.spatial(256, i0_fused_0 * 32 + i0_fused_1) + v_k = T.axis.reduce(256, k) + T.reads(A[v_i0, v_k]) + T.writes(T_softmax_maxelem[v_i0]) with T.init(): - T_softmax_maxelem[i0] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[i0] = T.max(T_softmax_maxelem[i0], A[i0, k]) + T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k]) for i0_fused in T.thread_binding(256, thread="blockIdx.x"): for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): with T.block("T_softmax_expsum"): - i0 = T.axis.spatial(256, i0_fused + ax0) - k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) + v_i0 = T.axis.spatial(256, i0_fused + ax0) + v_k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) - T.reads(A[i0, k], T_softmax_maxelem[i0]) - T.writes(T_softmax_expsum_shared[i0]) + T.reads(A[v_i0, v_k], T_softmax_maxelem[v_i0]) + T.writes(T_softmax_expsum_shared[v_i0]) with T.init(): - T_softmax_expsum_shared[i0] = T.float32(0) - T_softmax_expsum_shared[i0] = T_softmax_expsum_shared[i0] + T.exp(A[i0, k] - T_softmax_maxelem[i0], dtype="float32") - for i1_0 in T.serial(1): + T_softmax_expsum_shared[v_i0] = T.float32(0) + T_softmax_expsum_shared[v_i0] = T_softmax_expsum_shared[v_i0] + T.exp(A[v_i0, v_k] - T_softmax_maxelem[v_i0]) + for i1_0 in range(1): for i1_1 in 
T.thread_binding(512, thread="threadIdx.x"): with T.block("T_softmax_norm"): - i0 = T.axis.spatial(256, i0_fused) - i1 = T.axis.spatial(256, i1_0 * 512 + i1_1) + v_i0 = T.axis.spatial(256, i0_fused) + v_i1 = T.axis.spatial(256, i1_0 * 512 + i1_1) T.where(i1_0 * 512 + i1_1 < 256) - T.reads(A[i0, i1], T_softmax_maxelem[i0], T_softmax_expsum_shared[i0]) - T.writes(T_softmax_norm[i0, i1]) - T.block_attr({"axis":1}) - T_softmax_norm[i0, i1] = T.exp(A[i0, i1] - T_softmax_maxelem[i0], dtype="float32") / T_softmax_expsum_shared[i0] + T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum_shared[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum_shared[v_i0] @T.prim_func def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":0}) - T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") - T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T.block_attr({"meta_schedule.unroll_explicit": 0}) + T_softmax_maxelem_shared = T.alloc_buffer((256,), scope="shared") + T_softmax_expsum_shared = T.alloc_buffer((256,), scope="shared") for i0_fused in T.thread_binding(256, thread="blockIdx.x"): for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): with T.block("T_softmax_maxelem"): - i0 = T.axis.spatial(256, i0_fused + ax0) - k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) + v_i0 = T.axis.spatial(256, i0_fused + ax0) + v_k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) - T.reads(A[i0, k]) - T.writes(T_softmax_maxelem_shared[i0]) + T.reads(A[v_i0, v_k]) + T.writes(T_softmax_maxelem_shared[v_i0]) with T.init(): - T_softmax_maxelem_shared[i0] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_shared[i0] = T.max(T_softmax_maxelem_shared[i0], A[i0, k]) + T_softmax_maxelem_shared[v_i0] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_shared[v_i0] = T.max(T_softmax_maxelem_shared[v_i0], A[v_i0, v_k]) for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): with T.block("T_softmax_expsum"): - i0 = T.axis.spatial(256, i0_fused + ax0) - k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) + v_i0 = T.axis.spatial(256, i0_fused + ax0) + v_k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) - T.reads(A[i0, k], T_softmax_maxelem_shared[i0]) - T.writes(T_softmax_expsum_shared[i0]) + T.reads(A[v_i0, v_k], T_softmax_maxelem_shared[v_i0]) + T.writes(T_softmax_expsum_shared[v_i0]) with T.init(): - T_softmax_expsum_shared[i0] = T.float32(0) - T_softmax_expsum_shared[i0] = T_softmax_expsum_shared[i0] + T.exp(A[i0, k] - T_softmax_maxelem_shared[i0], dtype="float32") - for i1_0 in T.serial(1): + T_softmax_expsum_shared[v_i0] = T.float32(0) + T_softmax_expsum_shared[v_i0] = T_softmax_expsum_shared[v_i0] + T.exp(A[v_i0, v_k] - T_softmax_maxelem_shared[v_i0]) + for i1_0 in range(1): for i1_1 in T.thread_binding(512, thread="threadIdx.x"): with T.block("T_softmax_norm"): - i0 = T.axis.spatial(256, i0_fused) - i1 = T.axis.spatial(256, i1_0 * 512 + i1_1) + v_i0 = T.axis.spatial(256, i0_fused) + v_i1 = T.axis.spatial(256, i1_0 * 512 + 
i1_1) T.where(i1_0 * 512 + i1_1 < 256) - T.reads(A[i0, i1], T_softmax_maxelem_shared[i0], T_softmax_expsum_shared[i0]) - T.writes(T_softmax_norm[i0, i1]) - T.block_attr({"axis":1}) - T_softmax_norm[i0, i1] = T.exp(A[i0, i1] - T_softmax_maxelem_shared[i0], dtype="float32") / T_softmax_expsum_shared[i0] + T.reads(A[v_i0, v_i1], T_softmax_maxelem_shared[v_i0], T_softmax_expsum_shared[v_i0]) + T.writes(T_softmax_norm[v_i0, v_i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem_shared[v_i0]) / T_softmax_expsum_shared[v_i0] # fmt: on decision_0 = [ ("SampleCategorical", 0), @@ -1089,61 +1063,59 @@ def test_cuda_cbr(): # fmt: off @T.prim_func def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":512}) - Conv2dOutput_local = T.alloc_buffer([1, 112, 112, 64], dtype="float32", scope="local") - PaddedInput_shared = T.alloc_buffer([1, 230, 230, 3], dtype="float32", scope="shared") - kernel_shared = T.alloc_buffer([7, 7, 3, 64], dtype="float32", scope="shared") - for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(14, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(4, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(128, thread="threadIdx.x"): - for i4_0, i5_0, i6_0 in T.grid(7, 1, 3): - for ax0_ax1_ax2_ax3_fused in T.serial(8251): + T.block_attr({"meta_schedule.unroll_explicit": 512}) + Conv2dOutput_local = T.alloc_buffer((1, 112, 112, 64), scope="local") + PaddedInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared") + kernel_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared") + for nn_0_yy_0_xx_0_ff_0_fused in T.thread_binding(14, thread="blockIdx.x"): + for nn_1_yy_1_xx_1_ff_1_fused in T.thread_binding(4, thread="vthread.x"): + for nn_2_yy_2_xx_2_ff_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for ry_0, rx_0, rc_0 in T.grid(7, 1, 3): + for ax0_ax1_ax2_ax3_fused in range(8251): with T.block("PaddedInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(230, ax0_ax1_ax2_ax3_fused // 37 + i4_0) - v2 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_fused // 2 * 32 + ax0_ax1_ax2_ax3_fused % 37) - v3 = T.axis.spatial(3, i6_0) + v1 = T.axis.spatial(230, ry_0 + ax0_ax1_ax2_ax3_fused // 37) + v2 = T.axis.spatial(230, nn_0_yy_0_xx_0_ff_0_fused // 2 * 32 + ax0_ax1_ax2_ax3_fused % 37) + v3 = T.axis.spatial(3, rc_0) T.reads(data[v0, v1 - 3, v2 - 3, v3]) T.writes(PaddedInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":1}) - PaddedInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, data[v0, v1 - 3, v2 - 3, v3], T.float32(0), dtype="float32") - for ax0_ax1_ax2_ax3_fused in T.serial(224): + T.block_attr({"meta_schedule.cooperative_fetch": 1}) + PaddedInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, data[v0, v1 - 3, v2 - 3, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused in range(224): with T.block("kernel_shared"): - v0 = T.axis.spatial(7, i4_0) + v0 = T.axis.spatial(7, ry_0) v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 32) - 
v2 = T.axis.spatial(3, i6_0) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) + v2 = T.axis.spatial(3, rc_0) + v3 = T.axis.spatial(64, nn_0_yy_0_xx_0_ff_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(kernel[v0, v1, v2, v3]) T.writes(kernel_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":1}) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) kernel_shared[v0, v1, v2, v3] = kernel[v0, v1, v2, v3] - for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 1, 1, 1, 1, 2, 1, 7, 1, 1, 7, 1, 8): + for ry_1, rx_1, rc_1, nn_3, yy_3, xx_3, ff_3, ry_2, rx_2, rc_2, nn_4, yy_4, xx_4, ff_4 in T.grid(1, 1, 1, 1, 1, 1, 2, 1, 7, 1, 1, 7, 1, 8): with T.block("Conv2dOutput"): - nn = T.axis.spatial(1, i0_3 + i0_4) - yy = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_fused // 2 * 56 + i0_2_i1_2_i2_2_i3_2_fused // 16 * 7 + i1_3 * 7 + i1_4) - xx = T.axis.spatial(112, i2_4 + i0_0_i1_0_i2_0_i3_0_fused // 2 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + i2_3) - ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 16 + i3_3 * 8 + i3_4) - ry = T.axis.reduce(7, i4_0 + i4_1 + i4_2) - rx = T.axis.reduce(7, i5_0 * 7 + i5_1 * 7 + i5_2) - rc = T.axis.reduce(3, i6_1 + i6_2 + i6_0) - T.reads(PaddedInput_shared[nn, yy * 2 + ry, xx * 2 + rx, rc], kernel_shared[ry, rx, rc, ff]) - T.writes(Conv2dOutput_local[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_nn = T.axis.spatial(1, nn_3 + nn_4) + v_yy = T.axis.spatial(112, nn_1_yy_1_xx_1_ff_1_fused // 2 * 56 + nn_2_yy_2_xx_2_ff_2_fused // 16 * 7 + yy_3 * 7 + yy_4) + v_xx = T.axis.spatial(112, nn_0_yy_0_xx_0_ff_0_fused // 2 * 16 + nn_2_yy_2_xx_2_ff_2_fused % 16 + xx_3 + xx_4) + v_ff = T.axis.spatial(64, nn_0_yy_0_xx_0_ff_0_fused % 2 * 32 + nn_1_yy_1_xx_1_ff_1_fused % 2 * 16 + ff_3 * 8 + ff_4) + v_ry = T.axis.reduce(7, ry_0 + ry_1 + ry_2) + v_rx = T.axis.reduce(7, rx_0 * 7 + rx_1 * 7 + rx_2) + v_rc = T.axis.reduce(3, rc_0 + rc_1 + rc_2) + T.reads(PaddedInput_shared[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc], kernel_shared[v_ry, v_rx, v_rc, v_ff]) + T.writes(Conv2dOutput_local[v_nn, v_yy, v_xx, v_ff]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - Conv2dOutput_local[nn, yy, xx, ff] = T.float32(0) - Conv2dOutput_local[nn, yy, xx, ff] = Conv2dOutput_local[nn, yy, xx, ff] + PaddedInput_shared[nn, yy * 2 + ry, xx * 2 + rx, rc] * kernel_shared[ry, rx, rc, ff] + Conv2dOutput_local[v_nn, v_yy, v_xx, v_ff] = T.float32(0) + Conv2dOutput_local[v_nn, v_yy, v_xx, v_ff] = Conv2dOutput_local[v_nn, v_yy, v_xx, v_ff] + PaddedInput_shared[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc] * kernel_shared[v_ry, v_rx, v_rc, v_ff] for ax0, ax1, ax2, ax3 in T.grid(1, 7, 1, 16): with T.block("Conv2dOutput_local"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_fused // 2 * 56 + i0_2_i1_2_i2_2_i3_2_fused // 16 * 7 + ax1) - v2 = T.axis.spatial(112, i0_0_i1_0_i2_0_i3_0_fused // 2 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + ax2) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 16 + ax3) + v1 = T.axis.spatial(112, nn_1_yy_1_xx_1_ff_1_fused // 2 * 56 + nn_2_yy_2_xx_2_ff_2_fused // 16 * 7 + ax1) + v2 = T.axis.spatial(112, 
nn_0_yy_0_xx_0_ff_0_fused // 2 * 16 + nn_2_yy_2_xx_2_ff_2_fused % 16 + ax2) + v3 = T.axis.spatial(64, nn_0_yy_0_xx_0_ff_0_fused % 2 * 32 + nn_1_yy_1_xx_1_ff_1_fused % 2 * 16 + ax3) T.reads(Conv2dOutput_local[v0, v1, v2, v3], bias[v3], bn_scale[v3], bn_offset[v3]) T.writes(compute[v0, v1, v2, v3]) compute[v0, v1, v2, v3] = T.max((Conv2dOutput_local[v0, v1, v2, v3] + bias[v3]) * bn_scale[v3] + bn_offset[v3], T.float32(0)) @@ -1174,57 +1146,57 @@ def test_cuda_tbg(): # fmt: off @T.prim_func def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: T.Buffer((1, 12, 128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":1024}) - C_local = T.alloc_buffer([1, 12, 128, 128], dtype="float32", scope="local") - query_T_shared = T.alloc_buffer([1, 12, 128, 64], dtype="float32", scope="shared") - value_T_shared = T.alloc_buffer([1, 12, 64, 128], dtype="float32", scope="shared") - for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(4, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(192, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(32, thread="threadIdx.x"): - for i4_0 in T.serial(8): - for ax0_ax1_ax2_ax3_fused in T.serial(12288): + T.block_attr({"meta_schedule.unroll_explicit": 1024}) + C_local = T.alloc_buffer((1, 12, 128, 128), scope="local") + query_T_shared = T.alloc_buffer((1, 12, 128, 64), scope="shared") + value_T_shared = T.alloc_buffer((1, 12, 64, 128), scope="shared") + for b_0_h_0_i_0_j_0_fused in T.thread_binding(4, thread="blockIdx.x"): + for b_1_h_1_i_1_j_1_fused in T.thread_binding(192, thread="vthread.x"): + for b_2_h_2_i_2_j_2_fused in T.thread_binding(32, thread="threadIdx.x"): + for k_0 in range(8): + for ax0_ax1_ax2_ax3_fused in range(12288): with T.block("query_T_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(12, ax0_ax1_ax2_ax3_fused // 1024) v2 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused % 1024 // 8) - v3 = T.axis.spatial(64, i4_0 * 8 + ax0_ax1_ax2_ax3_fused % 8) + v3 = T.axis.spatial(64, k_0 * 8 + ax0_ax1_ax2_ax3_fused % 8) T.reads(query[v0, v2, v1, v3]) T.writes(query_T_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) query_T_shared[v0, v1, v2, v3] = query[v0, v2, v1, v3] - for ax0_ax1_ax2_ax3_fused in T.serial(3072): + for ax0_ax1_ax2_ax3_fused in range(3072): with T.block("value_T_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(12, ax0_ax1_ax2_ax3_fused // 256) - v2 = T.axis.spatial(64, i4_0 * 8 + ax0_ax1_ax2_ax3_fused % 256 // 32) - v3 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_fused * 32 + ax0_ax1_ax2_ax3_fused % 32) + v2 = T.axis.spatial(64, k_0 * 8 + ax0_ax1_ax2_ax3_fused % 256 // 32) + v3 = T.axis.spatial(128, b_0_h_0_i_0_j_0_fused * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(value[v0, v3, v1, v2]) T.writes(value_T_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) value_T_shared[v0, v1, v2, v3] = value[v0, v3, v1, v2] - for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(4, 1, 2, 1, 1, 2, 1, 1, 4, 1): + for k_1, b_3, h_3, i_3, j_3, k_2, b_4, h_4, i_4, j_4 in T.grid(4, 1, 2, 1, 1, 2, 1, 1, 4, 1): with T.block("C"): - b = T.axis.spatial(1, i0_4 + i0_3) - h = T.axis.spatial(12, i1_4 
+ i0_1_i1_1_i2_1_i3_1_fused // 32 * 2 + i1_3) - i = T.axis.spatial(128, i0_1_i1_1_i2_1_i3_1_fused % 32 // 8 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 4 * 4 + i2_3 * 4 + i2_4) - j = T.axis.spatial(128, i3_4 + i0_0_i1_0_i2_0_i3_0_fused * 32 + i0_1_i1_1_i2_1_i3_1_fused % 8 * 4 + i0_2_i1_2_i2_2_i3_2_fused % 4 + i3_3) - k = T.axis.reduce(64, i4_0 * 8 + i4_1 * 2 + i4_2) - T.reads(query_T_shared[b, h, i, k], value_T_shared[b, h, k, j]) - T.writes(C_local[b, h, i, j]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_b = T.axis.spatial(1, b_3 + b_4) + v_h = T.axis.spatial(12, b_1_h_1_i_1_j_1_fused // 32 * 2 + h_3 + h_4) + v_i = T.axis.spatial(128, b_1_h_1_i_1_j_1_fused % 32 // 8 * 32 + b_2_h_2_i_2_j_2_fused // 4 * 4 + i_3 * 4 + i_4) + v_j = T.axis.spatial(128, b_0_h_0_i_0_j_0_fused * 32 + b_1_h_1_i_1_j_1_fused % 8 * 4 + b_2_h_2_i_2_j_2_fused % 4 + j_3 + j_4) + v_k = T.axis.reduce(64, k_0 * 8 + k_1 * 2 + k_2) + T.reads(query_T_shared[v_b, v_h, v_i, v_k], value_T_shared[v_b, v_h, v_k, v_j]) + T.writes(C_local[v_b, v_h, v_i, v_j]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - C_local[b, h, i, j] = T.float32(0) - C_local[b, h, i, j] = C_local[b, h, i, j] + query_T_shared[b, h, i, k] * value_T_shared[b, h, k, j] + C_local[v_b, v_h, v_i, v_j] = T.float32(0) + C_local[v_b, v_h, v_i, v_j] = C_local[v_b, v_h, v_i, v_j] + query_T_shared[v_b, v_h, v_i, v_k] * value_T_shared[v_b, v_h, v_k, v_j] for ax0, ax1, ax2, ax3 in T.grid(1, 2, 4, 1): with T.block("C_local"): v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(12, i0_1_i1_1_i2_1_i3_1_fused // 32 * 2 + ax1) - v2 = T.axis.spatial(128, i0_1_i1_1_i2_1_i3_1_fused % 32 // 8 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 4 * 4 + ax2) - v3 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_fused * 32 + i0_1_i1_1_i2_1_i3_1_fused % 8 * 4 + i0_2_i1_2_i2_2_i3_2_fused % 4 + ax3) + v1 = T.axis.spatial(12, b_1_h_1_i_1_j_1_fused // 32 * 2 + ax1) + v2 = T.axis.spatial(128, b_1_h_1_i_1_j_1_fused % 32 // 8 * 32 + b_2_h_2_i_2_j_2_fused // 4 * 4 + ax2) + v3 = T.axis.spatial(128, b_0_h_0_i_0_j_0_fused * 32 + b_1_h_1_i_1_j_1_fused % 8 * 4 + b_2_h_2_i_2_j_2_fused % 4 + ax3) T.reads(C_local[v0, v1, v2, v3]) T.writes(C[v0, v1, v2, v3]) C[v0, v1, v2, v3] = C_local[v0, v1, v2, v3] diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_async.py b/tests/python/unittest/test_meta_schedule_space_cuda_async.py index d31d62669687..8ea067d24301 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda_async.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda_async.py @@ -44,7 +44,7 @@ def get_c2d_prim_func(stage: int): # fmt: off @T.prim_func def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() @@ -81,10 +81,10 @@ def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3 v_n = T.axis.spatial(1, n_3 + n_4) v_h = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + h_3 + h_4) v_w = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + w_3 + w_4) - v_co = T.axis.spatial(64, co_3 + co_4 + 
n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16) + v_co = T.axis.spatial(64, n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + co_3 + co_4) v_rh = T.axis.reduce(7, rh_0 * 7 + rh_1 + rh_2) v_rw = T.axis.reduce(7, rw_0 * 7 + rw_1 * 7 + rw_2) - v_rc = T.axis.reduce(3, rc_1 + rc_2 + rc_0) + v_rc = T.axis.reduce(3, rc_0 + rc_1 + rc_2) T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) @@ -100,13 +100,12 @@ def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3 T.reads(conv2d_nhwc_local[v0, v1, v2, v3]) T.writes(conv2d_nhwc[v0, v1, v2, v3]) conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3] - # fmt: on else: # fmt: off @T.prim_func def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() @@ -140,13 +139,13 @@ def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3 weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(7, 1, 1, 1, 1, 14, 1, 1, 7, 1, 1, 1, 1, 1): with T.block("conv2d_nhwc"): - v_n = T.axis.spatial(1, n_4 + n_3) - v_h = T.axis.spatial(112, h_3 + h_4 + n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16) - v_w = T.axis.spatial(112, w_4 + n_0_h_0_w_0_co_0_fused % 8 * 14 + w_3) - v_co = T.axis.spatial(64, co_4 + n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + co_3) - v_rh = T.axis.reduce(7, rh_2 + rh_1) + v_n = T.axis.spatial(1, n_3 + n_4) + v_h = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + h_3 + h_4) + v_w = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + w_3 + w_4) + v_co = T.axis.spatial(64, n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + co_3 + co_4) + v_rh = T.axis.reduce(7, rh_1 + rh_2) v_rw = T.axis.reduce(7, rw_1 * 7 + rw_2) - v_rc = T.axis.reduce(3, rc_2 + rh_0_rw_0_rc_0_fused + rc_1) + v_rc = T.axis.reduce(3, rh_0_rw_0_rc_0_fused + rc_1 + rc_2) T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) @@ -198,114 +197,112 @@ def get_gmm_prim_func(stage: int): if stage == 0: # fmt: off @T.prim_func - def gmm(A: T.Buffer((1, 1024, 1024), "float32"), B: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + def gmm(X: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32"), Z: T.Buffer((1, 1024, 1024), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() T.block_attr({"meta_schedule.unroll_explicit": 16}) - Y_local = T.alloc_buffer((1, 1024, 
1024), scope="local") - A_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") - B_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") + Z_local = T.alloc_buffer((1, 1024, 1024), scope="local") + X_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") + Y_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") for b_0_i_0_j_0_fused in T.thread_binding(256, thread="blockIdx.x"): for b_1_i_1_j_1_fused in T.thread_binding(32, thread="vthread.x"): for b_2_i_2_j_2_fused in T.thread_binding(64, thread="threadIdx.x"): for k_0 in range(64): for ax0_ax1_ax2_fused in range(1024): - with T.block("A_shared"): + with T.block("X_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + ax0_ax1_ax2_fused // 16) v2 = T.axis.spatial(1024, k_0 * 16 + ax0_ax1_ax2_fused % 16) - T.reads(A[v0, v1, v2]) - T.writes(A_shared[v0, v1, v2]) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) T.block_attr({"meta_schedule.cooperative_fetch": 4}) - A_shared[v0, v1, v2] = A[v0, v1, v2] + X_shared[v0, v1, v2] = X[v0, v1, v2] for ax0_ax1_ax2_fused in range(1024): - with T.block("B_shared"): + with T.block("Y_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1024, k_0 * 16 + ax0_ax1_ax2_fused // 64) v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + ax0_ax1_ax2_fused % 64) - T.reads(B[v0, v1, v2]) - T.writes(B_shared[v0, v1, v2]) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) T.block_attr({"meta_schedule.cooperative_fetch": 4}) - B_shared[v0, v1, v2] = B[v0, v1, v2] + Y_shared[v0, v1, v2] = Y[v0, v1, v2] for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(2, 1, 1, 1, 8, 1, 1, 2): - with T.block("Y"): - v_b = T.axis.spatial(1, b_4 + b_3) + with T.block("Z"): + v_b = T.axis.spatial(1, b_3 + b_4) v_i = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + i_3 + i_4) v_j = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + j_3 * 2 + j_4) v_k = T.axis.reduce(1024, k_0 * 16 + k_1 * 8 + k_2) - T.reads(A_shared[v_b, v_i, v_k], B_shared[v_b, v_k, v_j]) - T.writes(Y_local[v_b, v_i, v_j]) + T.reads(X_shared[v_b, v_i, v_k], Y_shared[v_b, v_k, v_j]) + T.writes(Z_local[v_b, v_i, v_j]) T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - Y_local[v_b, v_i, v_j] = T.float32(0) - Y_local[v_b, v_i, v_j] = Y_local[v_b, v_i, v_j] + A_shared[v_b, v_i, v_k] * B_shared[v_b, v_k, v_j] + Z_local[v_b, v_i, v_j] = T.float32(0) + Z_local[v_b, v_i, v_j] = Z_local[v_b, v_i, v_j] + X_shared[v_b, v_i, v_k] * Y_shared[v_b, v_k, v_j] for ax0, ax1, ax2 in T.grid(1, 1, 2): - with T.block("Y_local"): + with T.block("Z_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + ax1) v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + ax2) - T.reads(Y_local[v0, v1, v2]) - T.writes(Y[v0, v1, v2]) - Y[v0, v1, v2] = Y_local[v0, v1, v2] - + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] # fmt: on else: # fmt: off @T.prim_func - def gmm(A: T.Buffer((1, 1024, 1024), "float32"), B: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + def gmm(X: T.Buffer((1, 1024, 1024), "float32"), 
Y: T.Buffer((1, 1024, 1024), "float32"), Z: T.Buffer((1, 1024, 1024), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() T.block_attr({"meta_schedule.unroll_explicit": 16}) - Y_local = T.alloc_buffer((1, 1024, 1024), scope="local") - A_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") - B_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") + Z_local = T.alloc_buffer((1, 1024, 1024), scope="local") + X_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") + Y_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") for b_0_i_0_j_0_fused in T.thread_binding(256, thread="blockIdx.x"): for b_1_i_1_j_1_fused in T.thread_binding(32, thread="vthread.x"): for b_2_i_2_j_2_fused in T.thread_binding(64, thread="threadIdx.x"): for k_0_fused in T.serial(64, annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, stage - 2]}): for ax0_ax1_ax2_fused in range(1024): - with T.block("A_shared"): + with T.block("X_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + ax0_ax1_ax2_fused // 16) v2 = T.axis.spatial(1024, k_0_fused * 16 + ax0_ax1_ax2_fused % 16) - T.reads(A[v0, v1, v2]) - T.writes(A_shared[v0, v1, v2]) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) T.block_attr({"meta_schedule.cooperative_fetch": 4}) - A_shared[v0, v1, v2] = A[v0, v1, v2] + X_shared[v0, v1, v2] = X[v0, v1, v2] for ax0_ax1_ax2_fused in range(1024): - with T.block("B_shared"): + with T.block("Y_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1024, k_0_fused * 16 + ax0_ax1_ax2_fused // 64) v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + ax0_ax1_ax2_fused % 64) - T.reads(B[v0, v1, v2]) - T.writes(B_shared[v0, v1, v2]) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) T.block_attr({"meta_schedule.cooperative_fetch": 4}) - B_shared[v0, v1, v2] = B[v0, v1, v2] + Y_shared[v0, v1, v2] = Y[v0, v1, v2] for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(2, 1, 1, 1, 8, 1, 1, 2): - with T.block("Y"): + with T.block("Z"): v_b = T.axis.spatial(1, b_3 + b_4) - v_i = T.axis.spatial(1024, i_3 + i_4 + b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8) + v_i = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + i_3 + i_4) v_j = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + j_3 * 2 + j_4) v_k = T.axis.reduce(1024, k_0_fused * 16 + k_1 * 8 + k_2) - T.reads(A_shared[v_b, v_i, v_k], B_shared[v_b, v_k, v_j]) - T.writes(Y_local[v_b, v_i, v_j]) + T.reads(X_shared[v_b, v_i, v_k], Y_shared[v_b, v_k, v_j]) + T.writes(Z_local[v_b, v_i, v_j]) T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - Y_local[v_b, v_i, v_j] = T.float32(0) - Y_local[v_b, v_i, v_j] = Y_local[v_b, v_i, v_j] + A_shared[v_b, v_i, v_k] * B_shared[v_b, v_k, v_j] + Z_local[v_b, v_i, v_j] = T.float32(0) + Z_local[v_b, v_i, v_j] = Z_local[v_b, v_i, v_j] + X_shared[v_b, v_i, v_k] * Y_shared[v_b, v_k, v_j] for ax0, ax1, ax2 in T.grid(1, 1, 2): - with T.block("Y_local"): + with T.block("Z_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + ax1) v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + 
b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + ax2) - T.reads(Y_local[v0, v1, v2]) - T.writes(Y[v0, v1, v2]) - Y[v0, v1, v2] = Y_local[v0, v1, v2] - + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] # fmt: on return gmm diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py index e8ed3bb8b2a1..27fe47ab8699 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py @@ -43,131 +43,129 @@ def test_cuda_nhwc(): # fmt: off @T.prim_func def cuda_nhwc_0(data: T.Buffer((1, 14, 14, 128), "float32"), weight: T.Buffer((6, 6, 128, 128), "float32"), conv2d_winograd: T.Buffer((1, 12, 12, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) - # body + T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":16}) - input_tile_local = T.alloc_buffer([6, 6, 9, 128], dtype="float32", scope="local") - data_pack = T.alloc_buffer([6, 6, 9, 128], dtype="float32") - bgemm = T.alloc_buffer([6, 6, 9, 128], dtype="float32") - inverse = T.alloc_buffer([4, 4, 9, 128], dtype="float32") - data_pack_local = T.alloc_buffer([6, 6, 9, 128], dtype="float32", scope="local") - bgemm_local = T.alloc_buffer([6, 6, 9, 128], dtype="float32", scope="local") - data_pack_shared = T.alloc_buffer([6, 6, 9, 128], dtype="float32", scope="shared") - weight_shared = T.alloc_buffer([6, 6, 128, 128], dtype="float32", scope="shared") - for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(2, thread="blockIdx.x"): - for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + T.block_attr({"meta_schedule.unroll_explicit": 16}) + input_tile_local = T.alloc_buffer((6, 6, 9, 128), scope="local") + data_pack = T.alloc_buffer((6, 6, 9, 128)) + bgemm = T.alloc_buffer((6, 6, 9, 128)) + inverse = T.alloc_buffer((4, 4, 9, 128)) + data_pack_local = T.alloc_buffer((6, 6, 9, 128), scope="local") + bgemm_local = T.alloc_buffer((6, 6, 9, 128), scope="local") + data_pack_shared = T.alloc_buffer((6, 6, 9, 128), scope="shared") + weight_shared = T.alloc_buffer((6, 6, 128, 128), scope="shared") + for p_0_ci_0_p_1_ci_1_fused_0 in T.thread_binding(2, thread="blockIdx.x"): + for p_0_ci_0_p_1_ci_1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): with T.block("input_tile"): - T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1 < 1152) - eps, nu = T.axis.remap("SS", [ax0, ax1]) - p = T.axis.spatial(9, (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) // 384 * 3 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 24 // 8 + ax2) - ci = T.axis.spatial(128, (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 384 // 24 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8 + ax3) - T.reads(data[p // 9, p % 9 // 3 * 4 + eps, p % 3 * 4 + nu, ci]) - T.writes(input_tile_local[eps, nu, p, ci]) - T.block_attr({"schedule_rule":"None"}) - input_tile_local[eps, nu, p, ci] = T.if_then_else(0 <= p % 9 // 3 * 4 + eps and p % 9 // 3 * 4 + eps < 14 and 0 <= p % 3 * 4 + nu and p % 3 * 4 + nu < 14, data[p // 9, p % 9 // 3 * 4 + eps, p % 3 * 4 + nu, ci], T.float32(0), 
dtype="float32") - for i0 in T.unroll(6): - for i1 in T.unroll(6): - for i4 in T.unroll(6): - for i5 in T.unroll(6): + v_eps, v_nu = T.axis.remap("SS", [ax0, ax1]) + v_p = T.axis.spatial(9, (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) // 384 * 3 + (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 24 // 8 + ax2) + v_ci = T.axis.spatial(128, (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 384 // 24 * 8 + (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 8 + ax3) + T.where(p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1 < 1152) + T.reads(data[v_p // 9, v_p % 9 // 3 * 4 + v_eps, v_p % 3 * 4 + v_nu, v_ci]) + T.writes(input_tile_local[v_eps, v_nu, v_p, v_ci]) + T.block_attr({"schedule_rule": "None"}) + input_tile_local[v_eps, v_nu, v_p, v_ci] = T.if_then_else(0 <= v_p % 9 // 3 * 4 + v_eps and v_p % 9 // 3 * 4 + v_eps < 14 and 0 <= v_p % 3 * 4 + v_nu and v_p % 3 * 4 + v_nu < 14, data[v_p // 9, v_p % 9 // 3 * 4 + v_eps, v_p % 3 * 4 + v_nu, v_ci], T.float32(0)) + for eps in T.unroll(6): + for nu in T.unroll(6): + for r_a in T.unroll(6): + for r_b in T.unroll(6): with T.block("data_pack"): - T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1 < 1152) - eps, nu = T.axis.remap("SS", [i0, i1]) - p = T.axis.spatial(9, (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) // 384 * 3 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 24 // 8) - ci = T.axis.spatial(128, (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 384 // 24 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8) - r_a, r_b = T.axis.remap("RR", [i4, i5]) - T.reads(input_tile_local[r_a, r_b, p, ci]) - T.writes(data_pack_local[eps, nu, p, ci]) - T.block_attr({"schedule_rule":"conv2d_nhwc_winograd_data_pack"}) + v_eps, v_nu = T.axis.remap("SS", [eps, nu]) + v_p = T.axis.spatial(9, (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) // 384 * 3 + (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 24 // 8) + v_ci = T.axis.spatial(128, (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 384 // 24 * 8 + (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 8) + v_r_a, v_r_b = T.axis.remap("RR", [r_a, r_b]) + T.where(p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1 < 1152) + T.reads(input_tile_local[v_r_a, v_r_b, v_p, v_ci]) + T.writes(data_pack_local[v_eps, v_nu, v_p, v_ci]) + T.block_attr({"schedule_rule": "conv2d_nhwc_winograd_data_pack"}) with T.init(): - data_pack_local[eps, nu, p, ci] = T.float32(0) - data_pack_local[eps, nu, p, ci] = data_pack_local[eps, nu, p, ci] + input_tile_local[r_a, r_b, p, ci] * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), 
T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_b % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_b % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_b % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_b % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_b % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_b % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_b % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_b % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_b % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_b % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_b % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_b % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_b % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_b % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + data_pack_local[v_eps, v_nu, v_p, v_ci] = T.float32(0) + data_pack_local[v_eps, v_nu, v_p, v_ci] = data_pack_local[v_eps, v_nu, v_p, v_ci] + input_tile_local[v_r_a, v_r_b, v_p, v_ci] * T.Select(v_r_a % 6 == 5 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 3, 
T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 0, T.float32(0), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 5, T.float32(1.5), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 4, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 3, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 2, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 0, T.float32(1), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 5, T.float32(-2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 4, T.float32(-0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 3, T.float32(2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 2, T.float32(2.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 1, T.float32(0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 0, T.float32(1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 5, T.float32(-1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 4, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 3, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 2, T.float32(0.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 1, T.float32(-2.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 0, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 4, T.float32(0.5), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 3, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 2, T.float32(-1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 0, T.float32(-1.5), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 5, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 3, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(v_r_b % 6 == 5 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 0, T.float32(0), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 5, T.float32(1.5), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 4, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 3, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 2, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 0, T.float32(1), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 5, T.float32(-2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 4, T.float32(-0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 3, T.float32(2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 2, T.float32(2.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 1, T.float32(0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 0, T.float32(1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 5, T.float32(-1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 4, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 3, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 2, T.float32(0.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 1, T.float32(-2.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 0, T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 4, T.float32(0.5), T.Select(v_r_b % 6 == 1 and v_nu % 6 
== 3, T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 2, T.float32(-1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 0, T.float32(-1.5), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 5, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): with T.block("data_pack_local"): - T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1 < 1152) v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(9, (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) // 384 * 3 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 24 // 8 + ax2) - v3 = T.axis.spatial(128, (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 384 // 24 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8 + ax3) + v2 = T.axis.spatial(9, (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) // 384 * 3 + (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 24 // 8 + ax2) + v3 = T.axis.spatial(128, (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 384 // 24 * 8 + (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 8 + ax3) + T.where(p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1 < 1152) T.reads(data_pack_local[v0, v1, v2, v3]) T.writes(data_pack[v0, v1, v2, v3]) data_pack[v0, v1, v2, v3] = data_pack_local[v0, v1, v2, v3] - for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(96, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(4, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(27, thread="threadIdx.x"): - for i4_0 in T.serial(8): - for ax0_ax1_ax2_ax3_fused in T.serial(1728): + for eps_0_nu_0_p_0_co_0_fused in T.thread_binding(96, thread="blockIdx.x"): + for eps_1_nu_1_p_1_co_1_fused in T.thread_binding(4, thread="vthread.x"): + for eps_2_nu_2_p_2_co_2_fused in T.thread_binding(27, thread="threadIdx.x"): + for ci_0 in range(8): + for ax0_ax1_ax2_ax3_fused in range(1728): with T.block("data_pack_shared"): - v0 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 32 * 2 + ax0_ax1_ax2_ax3_fused // 864) + v0 = T.axis.spatial(6, eps_0_nu_0_p_0_co_0_fused // 32 * 2 + ax0_ax1_ax2_ax3_fused // 864) v1 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused % 864 // 144) v2 = T.axis.spatial(9, ax0_ax1_ax2_ax3_fused % 144 // 16) - v3 = T.axis.spatial(128, i4_0 * 16 + ax0_ax1_ax2_ax3_fused % 16) + v3 = T.axis.spatial(128, ci_0 * 16 + ax0_ax1_ax2_ax3_fused % 16) T.reads(data_pack[v0, v1, v2, v3]) T.writes(data_pack_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":1}) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) data_pack_shared[v0, v1, v2, v3] = data_pack[v0, v1, v2, v3] - for ax0_ax1_ax2_ax3_fused in T.serial(768): + for ax0_ax1_ax2_ax3_fused in range(768): with T.block("weight_shared"): - v0 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 32 * 2 + ax0_ax1_ax2_ax3_fused // 384) + v0 = T.axis.spatial(6, eps_0_nu_0_p_0_co_0_fused // 32 * 2 + ax0_ax1_ax2_ax3_fused // 384) v1 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused % 384 // 64) - v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_fused % 32 * 4 + ax0_ax1_ax2_ax3_fused % 64 // 16) - v3 = 
T.axis.spatial(128, i4_0 * 16 + ax0_ax1_ax2_ax3_fused % 16) + v2 = T.axis.spatial(128, eps_0_nu_0_p_0_co_0_fused % 32 * 4 + ax0_ax1_ax2_ax3_fused % 64 // 16) + v3 = T.axis.spatial(128, ci_0 * 16 + ax0_ax1_ax2_ax3_fused % 16) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] - for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 2, 1, 1, 2, 16, 1, 1, 1, 1): + for ci_1, eps_3, nu_3, p_3, co_3, ci_2, eps_4, nu_4, p_4, co_4 in T.grid(1, 2, 1, 1, 2, 16, 1, 1, 1, 1): with T.block("bgemm"): - eps = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 32 * 2 + i0_3 + i0_4) - nu = T.axis.spatial(6, i1_3 + i1_4 + i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 9) - p = T.axis.spatial(9, i0_2_i1_2_i2_2_i3_2_fused % 9 + i2_3 + i2_4) - co = T.axis.spatial(128, i3_4 + i0_0_i1_0_i2_0_i3_0_fused % 32 * 4 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 2 + i3_3) - ci = T.axis.reduce(128, i4_0 * 16 + i4_1 * 16 + i4_2) - T.reads(data_pack_shared[eps, nu, p, ci], weight_shared[eps, nu, co, ci]) - T.writes(bgemm_local[eps, nu, p, co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS", "meta_schedule.write_cache_level":[3]}) + v_eps = T.axis.spatial(6, eps_0_nu_0_p_0_co_0_fused // 32 * 2 + eps_3 + eps_4) + v_nu = T.axis.spatial(6, eps_1_nu_1_p_1_co_1_fused // 2 * 3 + eps_2_nu_2_p_2_co_2_fused // 9 + nu_3 + nu_4) + v_p = T.axis.spatial(9, eps_2_nu_2_p_2_co_2_fused % 9 + p_3 + p_4) + v_co = T.axis.spatial(128, eps_0_nu_0_p_0_co_0_fused % 32 * 4 + eps_1_nu_1_p_1_co_1_fused % 2 * 2 + co_3 + co_4) + v_ci = T.axis.reduce(128, ci_0 * 16 + ci_1 * 16 + ci_2) + T.reads(data_pack_shared[v_eps, v_nu, v_p, v_ci], weight_shared[v_eps, v_nu, v_co, v_ci]) + T.writes(bgemm_local[v_eps, v_nu, v_p, v_co]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS", "meta_schedule.write_cache_level": [3]}) with T.init(): - bgemm_local[eps, nu, p, co] = T.float32(0) - bgemm_local[eps, nu, p, co] = bgemm_local[eps, nu, p, co] + data_pack_shared[eps, nu, p, ci] * weight_shared[eps, nu, co, ci] + bgemm_local[v_eps, v_nu, v_p, v_co] = T.float32(0) + bgemm_local[v_eps, v_nu, v_p, v_co] = bgemm_local[v_eps, v_nu, v_p, v_co] + data_pack_shared[v_eps, v_nu, v_p, v_ci] * weight_shared[v_eps, v_nu, v_co, v_ci] for ax0, ax1, ax2, ax3 in T.grid(2, 1, 1, 2): with T.block("bgemm_local"): - v0 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 32 * 2 + ax0) - v1 = T.axis.spatial(6, i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 9 + ax1) - v2 = T.axis.spatial(9, i0_2_i1_2_i2_2_i3_2_fused % 9 + ax2) - v3 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_fused % 32 * 4 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 2 + ax3) + v0 = T.axis.spatial(6, eps_0_nu_0_p_0_co_0_fused // 32 * 2 + ax0) + v1 = T.axis.spatial(6, eps_1_nu_1_p_1_co_1_fused // 2 * 3 + eps_2_nu_2_p_2_co_2_fused // 9 + ax1) + v2 = T.axis.spatial(9, eps_2_nu_2_p_2_co_2_fused % 9 + ax2) + v3 = T.axis.spatial(128, eps_0_nu_0_p_0_co_0_fused % 32 * 4 + eps_1_nu_1_p_1_co_1_fused % 2 * 2 + ax3) T.reads(bgemm_local[v0, v1, v2, v3]) T.writes(bgemm[v0, v1, v2, v3]) bgemm[v0, v1, v2, v3] = bgemm_local[v0, v1, v2, v3] - for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(18, 
thread="blockIdx.x"): - for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): - for i0 in T.unroll(4): - for i1 in T.unroll(4): - for i4 in T.unroll(6): - for i5 in T.unroll(6): + for p_0_co_0_p_1_co_1_fused_0 in T.thread_binding(18, thread="blockIdx.x"): + for p_0_co_0_p_1_co_1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + for vh in T.unroll(4): + for vw in T.unroll(4): + for r_a in T.unroll(6): + for r_b in T.unroll(6): with T.block("inverse"): - vh, vw = T.axis.remap("SS", [i0, i1]) - p = T.axis.spatial(9, (i2_0_i3_0_i2_1_i3_1_fused_0 * 64 + i2_0_i3_0_i2_1_i3_1_fused_1) // 384 * 3 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 64 + i2_0_i3_0_i2_1_i3_1_fused_1) % 24 // 8) - co = T.axis.spatial(128, (i2_0_i3_0_i2_1_i3_1_fused_0 * 64 + i2_0_i3_0_i2_1_i3_1_fused_1) % 384 // 24 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 64 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8) - r_a, r_b = T.axis.remap("RR", [i4, i5]) - T.reads(bgemm[r_a, r_b, p, co]) - T.writes(inverse[vh, vw, p, co]) - T.block_attr({"schedule_rule":"conv2d_nhwc_winograd_inverse"}) + v_vh, v_vw = T.axis.remap("SS", [vh, vw]) + v_p = T.axis.spatial(9, (p_0_co_0_p_1_co_1_fused_0 * 64 + p_0_co_0_p_1_co_1_fused_1) // 384 * 3 + (p_0_co_0_p_1_co_1_fused_0 * 64 + p_0_co_0_p_1_co_1_fused_1) % 24 // 8) + v_co = T.axis.spatial(128, (p_0_co_0_p_1_co_1_fused_0 * 64 + p_0_co_0_p_1_co_1_fused_1) % 384 // 24 * 8 + (p_0_co_0_p_1_co_1_fused_0 * 64 + p_0_co_0_p_1_co_1_fused_1) % 8) + v_r_a, v_r_b = T.axis.remap("RR", [r_a, r_b]) + T.reads(bgemm[v_r_a, v_r_b, v_p, v_co]) + T.writes(inverse[v_vh, v_vw, v_p, v_co]) + T.block_attr({"schedule_rule": "conv2d_nhwc_winograd_inverse"}) with T.init(): - inverse[vh, vw, p, co] = T.float32(0) - inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co] * T.Select(r_a % 6 == 5 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 5 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 0, T.float32(0), T.Select(r_a % 6 == 4 and vh % 4 == 3, T.float32(-8), T.Select(r_a % 6 == 4 and vh % 4 == 2, T.float32(4), T.Select(r_a % 6 == 4 and vh % 4 == 1, T.float32(-2), T.Select(r_a % 6 == 4 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 3 and vh % 4 == 3, T.float32(0.125), T.Select(r_a % 6 == 3 and vh % 4 == 2, T.float32(0.25), T.Select(r_a % 6 == 3 and vh % 4 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 1, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 3, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 1, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 0 and vh % 4 == 3, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 5 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 0, T.float32(0), T.Select(r_b % 6 == 4 and vw % 4 == 3, T.float32(-8), T.Select(r_b % 6 == 4 and vw % 4 == 2, T.float32(4), T.Select(r_b % 6 == 4 and vw % 4 == 1, T.float32(-2), T.Select(r_b % 6 == 4 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 3 
and vw % 4 == 3, T.float32(0.125), T.Select(r_b % 6 == 3 and vw % 4 == 2, T.float32(0.25), T.Select(r_b % 6 == 3 and vw % 4 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 1, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 3, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 1, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 0 and vw % 4 == 3, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) - for i0_i1_i2_i3_fused_0 in T.thread_binding(144, thread="blockIdx.x"): - for i0_i1_i2_i3_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + inverse[v_vh, v_vw, v_p, v_co] = T.float32(0) + inverse[v_vh, v_vw, v_p, v_co] = inverse[v_vh, v_vw, v_p, v_co] + bgemm[v_r_a, v_r_b, v_p, v_co] * T.Select(v_r_a % 6 == 5 and v_vh % 4 == 3, T.float32(1), T.Select(v_r_a % 6 == 5 and v_vh % 4 == 2, T.float32(0), T.Select(v_r_a % 6 == 5 and v_vh % 4 == 1, T.float32(0), T.Select(v_r_a % 6 == 5 and v_vh % 4 == 0, T.float32(0), T.Select(v_r_a % 6 == 4 and v_vh % 4 == 3, T.float32(-8), T.Select(v_r_a % 6 == 4 and v_vh % 4 == 2, T.float32(4), T.Select(v_r_a % 6 == 4 and v_vh % 4 == 1, T.float32(-2), T.Select(v_r_a % 6 == 4 and v_vh % 4 == 0, T.float32(1), T.Select(v_r_a % 6 == 3 and v_vh % 4 == 3, T.float32(0.125), T.Select(v_r_a % 6 == 3 and v_vh % 4 == 2, T.float32(0.25), T.Select(v_r_a % 6 == 3 and v_vh % 4 == 1, T.float32(0.5), T.Select(v_r_a % 6 == 3 and v_vh % 4 == 0, T.float32(1), T.Select(v_r_a % 6 == 2 and v_vh % 4 == 3, T.float32(1), T.Select(v_r_a % 6 == 2 and v_vh % 4 == 2, T.float32(1), T.Select(v_r_a % 6 == 2 and v_vh % 4 == 1, T.float32(1), T.Select(v_r_a % 6 == 2 and v_vh % 4 == 0, T.float32(1), T.Select(v_r_a % 6 == 1 and v_vh % 4 == 3, T.float32(-1), T.Select(v_r_a % 6 == 1 and v_vh % 4 == 2, T.float32(1), T.Select(v_r_a % 6 == 1 and v_vh % 4 == 1, T.float32(-1), T.Select(v_r_a % 6 == 1 and v_vh % 4 == 0, T.float32(1), T.Select(v_r_a % 6 == 0 and v_vh % 4 == 3, T.float32(0), T.Select(v_r_a % 6 == 0 and v_vh % 4 == 2, T.float32(0), T.Select(v_r_a % 6 == 0 and v_vh % 4 == 1, T.float32(0), T.Select(v_r_a % 6 == 0 and v_vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(v_r_b % 6 == 5 and v_vw % 4 == 3, T.float32(1), T.Select(v_r_b % 6 == 5 and v_vw % 4 == 2, T.float32(0), T.Select(v_r_b % 6 == 5 and v_vw % 4 == 1, T.float32(0), T.Select(v_r_b % 6 == 5 and v_vw % 4 == 0, T.float32(0), T.Select(v_r_b % 6 == 4 and v_vw % 4 == 3, T.float32(-8), T.Select(v_r_b % 6 == 4 and v_vw % 4 == 2, T.float32(4), T.Select(v_r_b % 6 == 4 and v_vw % 4 == 1, T.float32(-2), T.Select(v_r_b % 6 == 4 and v_vw % 4 == 0, T.float32(1), T.Select(v_r_b % 6 == 3 and v_vw % 4 == 3, T.float32(0.125), T.Select(v_r_b % 6 == 3 and v_vw % 4 == 2, T.float32(0.25), T.Select(v_r_b % 6 == 3 and v_vw % 4 == 1, T.float32(0.5), T.Select(v_r_b % 6 == 3 and v_vw % 4 == 0, T.float32(1), T.Select(v_r_b % 6 == 2 and v_vw % 4 == 3, T.float32(1), T.Select(v_r_b % 6 == 2 and v_vw % 4 == 2, T.float32(1), T.Select(v_r_b % 6 == 2 and v_vw % 4 == 1, T.float32(1), T.Select(v_r_b % 6 == 2 and v_vw % 4 == 0, T.float32(1), T.Select(v_r_b % 6 == 1 and v_vw % 4 == 
3, T.float32(-1), T.Select(v_r_b % 6 == 1 and v_vw % 4 == 2, T.float32(1), T.Select(v_r_b % 6 == 1 and v_vw % 4 == 1, T.float32(-1), T.Select(v_r_b % 6 == 1 and v_vw % 4 == 0, T.float32(1), T.Select(v_r_b % 6 == 0 and v_vw % 4 == 3, T.float32(0), T.Select(v_r_b % 6 == 0 and v_vw % 4 == 2, T.float32(0), T.Select(v_r_b % 6 == 0 and v_vw % 4 == 1, T.float32(0), T.Select(v_r_b % 6 == 0 and v_vw % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) + for n_h_w_co_fused_0 in T.thread_binding(144, thread="blockIdx.x"): + for n_h_w_co_fused_1 in T.thread_binding(128, thread="threadIdx.x"): with T.block("conv2d_winograd"): - n = T.axis.spatial(1, 0) - h = T.axis.spatial(12, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) // 1536) - w = T.axis.spatial(12, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) % 1536 // 128) - co = T.axis.spatial(128, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) % 128) - T.reads(inverse[h % 4, w % 4, n * 9 + h // 4 * 3 + w // 4, co]) - T.writes(conv2d_winograd[n, h, w, co]) - conv2d_winograd[n, h, w, co] = inverse[h % 4, w % 4, n * 9 + h // 4 * 3 + w // 4, co] + v_n = T.axis.spatial(1, 0) + v_h = T.axis.spatial(12, (n_h_w_co_fused_0 * 128 + n_h_w_co_fused_1) // 1536) + v_w = T.axis.spatial(12, (n_h_w_co_fused_0 * 128 + n_h_w_co_fused_1) % 1536 // 128) + v_co = T.axis.spatial(128, (n_h_w_co_fused_0 * 128 + n_h_w_co_fused_1) % 128) + T.reads(inverse[v_h % 4, v_w % 4, v_n * 9 + v_h // 4 * 3 + v_w // 4, v_co]) + T.writes(conv2d_winograd[v_n, v_h, v_w, v_co]) + conv2d_winograd[v_n, v_h, v_w, v_co] = inverse[v_h % 4, v_w % 4, v_n * 9 + v_h // 4 * 3 + v_w // 4, v_co] # fmt: on decision_0 = [ ("SamplePerfectTile", [3, 3]), @@ -201,130 +199,128 @@ def test_cuda_nchw(): # fmt: off @T.prim_func def cuda_nchw_0(data: T.Buffer((1, 64, 56, 56), "float32"), weight: T.Buffer((6, 6, 64, 64), "float32"), conv2d_winograd: T.Buffer((1, 64, 56, 56), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) - # body + T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":16}) - input_tile_local = T.alloc_buffer([64, 196, 6, 6], dtype="float32", scope="local") - data_pack = T.alloc_buffer([6, 6, 64, 196], dtype="float32") - bgemm = T.alloc_buffer([6, 6, 64, 196], dtype="float32") - inverse_local = T.alloc_buffer([64, 196, 4, 4], dtype="float32", scope="local") - data_pack_local = T.alloc_buffer([6, 6, 64, 196], dtype="float32", scope="local") - bgemm_local = T.alloc_buffer([6, 6, 64, 196], dtype="float32", scope="local") - data_pack_shared = T.alloc_buffer([6, 6, 64, 196], dtype="float32", scope="shared") - weight_shared = T.alloc_buffer([6, 6, 64, 64], dtype="float32", scope="shared") - for i2_i3_fused_0 in T.thread_binding(25, thread="blockIdx.x"): - for i2_i3_fused_1 in T.thread_binding(512, thread="threadIdx.x"): + T.block_attr({"meta_schedule.unroll_explicit": 16}) + input_tile_local = T.alloc_buffer((64, 196, 6, 6), scope="local") + data_pack = T.alloc_buffer((6, 6, 64, 196)) + bgemm = T.alloc_buffer((6, 6, 64, 196)) + inverse_local = T.alloc_buffer((64, 196, 4, 4), scope="local") + data_pack_local = T.alloc_buffer((6, 6, 64, 196), scope="local") + bgemm_local = T.alloc_buffer((6, 6, 64, 196), scope="local") + data_pack_shared = T.alloc_buffer((6, 6, 64, 196), scope="shared") + weight_shared = T.alloc_buffer((6, 6, 64, 64), scope="shared") + for ci_p_fused_0 in 
T.thread_binding(25, thread="blockIdx.x"): + for ci_p_fused_1 in T.thread_binding(512, thread="threadIdx.x"): for ax0, ax1, ax2, ax3 in T.grid(1, 1, 6, 6): with T.block("input_tile"): - T.where(i2_i3_fused_0 * 512 + i2_i3_fused_1 < 12544) - ci = T.axis.spatial(64, (i2_i3_fused_0 * 512 + i2_i3_fused_1) // 196 + ax0) - p = T.axis.spatial(196, (i2_i3_fused_0 * 120 + i2_i3_fused_1) % 196 + ax1) - eps, nu = T.axis.remap("SS", [ax2, ax3]) - T.reads(data[p // 196, ci, p % 196 // 14 * 4 + eps - 1, p % 14 * 4 + nu - 1]) - T.writes(input_tile_local[ci, p, eps, nu]) - T.block_attr({"schedule_rule":"None"}) - input_tile_local[ci, p, eps, nu] = T.if_then_else(1 <= p % 196 // 14 * 4 + eps and p % 196 // 14 * 4 + eps < 57 and 1 <= p % 14 * 4 + nu and p % 14 * 4 + nu < 57, data[p // 196, ci, p % 196 // 14 * 4 + eps - 1, p % 14 * 4 + nu - 1], T.float32(0), dtype="float32") - for i0 in T.unroll(6): - for i1 in T.unroll(6): - for i4 in T.unroll(6): - for i5 in T.unroll(6): + v_ci = T.axis.spatial(64, (ci_p_fused_0 * 512 + ci_p_fused_1) // 196 + ax0) + v_p = T.axis.spatial(196, (ci_p_fused_0 * 120 + ci_p_fused_1) % 196 + ax1) + v_eps, v_nu = T.axis.remap("SS", [ax2, ax3]) + T.where(ci_p_fused_0 * 512 + ci_p_fused_1 < 12544) + T.reads(data[v_p // 196, v_ci, v_p % 196 // 14 * 4 + v_eps - 1, v_p % 14 * 4 + v_nu - 1]) + T.writes(input_tile_local[v_ci, v_p, v_eps, v_nu]) + T.block_attr({"schedule_rule": "None"}) + input_tile_local[v_ci, v_p, v_eps, v_nu] = T.if_then_else(1 <= v_p % 196 // 14 * 4 + v_eps and v_p % 196 // 14 * 4 + v_eps < 57 and 1 <= v_p % 14 * 4 + v_nu and v_p % 14 * 4 + v_nu < 57, data[v_p // 196, v_ci, v_p % 196 // 14 * 4 + v_eps - 1, v_p % 14 * 4 + v_nu - 1], T.float32(0)) + for eps in T.unroll(6): + for nu in T.unroll(6): + for r_a in T.unroll(6): + for r_b in T.unroll(6): with T.block("data_pack"): - T.where(i2_i3_fused_0 * 512 + i2_i3_fused_1 < 12544) - eps, nu = T.axis.remap("SS", [i0, i1]) - ci = T.axis.spatial(64, (i2_i3_fused_0 * 512 + i2_i3_fused_1) // 196) - p = T.axis.spatial(196, (i2_i3_fused_0 * 512 + i2_i3_fused_1) % 196) - r_a, r_b = T.axis.remap("RR", [i4, i5]) - T.reads(input_tile_local[ci, p, r_a, r_b]) - T.writes(data_pack_local[eps, nu, ci, p]) - T.block_attr({"schedule_rule":"conv2d_nchw_winograd_data_pack"}) + v_eps, v_nu = T.axis.remap("SS", [eps, nu]) + v_ci = T.axis.spatial(64, (ci_p_fused_0 * 512 + ci_p_fused_1) // 196) + v_p = T.axis.spatial(196, (ci_p_fused_0 * 512 + ci_p_fused_1) % 196) + v_r_a, v_r_b = T.axis.remap("RR", [r_a, r_b]) + T.where(ci_p_fused_0 * 512 + ci_p_fused_1 < 12544) + T.reads(input_tile_local[v_ci, v_p, v_r_a, v_r_b]) + T.writes(data_pack_local[v_eps, v_nu, v_ci, v_p]) + T.block_attr({"schedule_rule": "conv2d_nchw_winograd_data_pack"}) with T.init(): - data_pack_local[eps, nu, ci, p] = T.float32(0) - data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + input_tile_local[ci, p, r_a, r_b] * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), 
T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_b % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_b % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_b % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_b % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_b % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_b % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_b % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_b % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_b % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_b % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_b % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_b % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_b % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_b % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + data_pack_local[v_eps, v_nu, v_ci, v_p] = T.float32(0) + data_pack_local[v_eps, v_nu, v_ci, v_p] = data_pack_local[v_eps, v_nu, v_ci, v_p] + input_tile_local[v_ci, v_p, v_r_a, v_r_b] * 
T.Select(v_r_a % 6 == 5 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 3, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 0, T.float32(0), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 5, T.float32(1.5), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 4, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 3, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 2, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 0, T.float32(1), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 5, T.float32(-2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 4, T.float32(-0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 3, T.float32(2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 2, T.float32(2.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 1, T.float32(0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 0, T.float32(1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 5, T.float32(-1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 4, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 3, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 2, T.float32(0.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 1, T.float32(-2.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 0, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 4, T.float32(0.5), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 3, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 2, T.float32(-1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 0, T.float32(-1.5), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 5, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 3, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(v_r_b % 6 == 5 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 0, T.float32(0), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 5, T.float32(1.5), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 4, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 3, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 2, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 0, T.float32(1), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 5, T.float32(-2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 4, T.float32(-0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 3, T.float32(2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 2, T.float32(2.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 1, T.float32(0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 0, T.float32(1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 5, T.float32(-1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 4, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 3, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 2, T.float32(0.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 1, T.float32(-2.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 0, 
T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 4, T.float32(0.5), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 3, T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 2, T.float32(-1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 0, T.float32(-1.5), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 5, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): with T.block("data_pack_local"): - T.where(i2_i3_fused_0 * 512 + i2_i3_fused_1 < 12544) v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(64, (i2_i3_fused_0 * 512 + i2_i3_fused_1) // 196 + ax2) - v3 = T.axis.spatial(196, (i2_i3_fused_0 * 120 + i2_i3_fused_1) % 196 + ax3) + v2 = T.axis.spatial(64, (ci_p_fused_0 * 512 + ci_p_fused_1) // 196 + ax2) + v3 = T.axis.spatial(196, (ci_p_fused_0 * 120 + ci_p_fused_1) % 196 + ax3) + T.where(ci_p_fused_0 * 512 + ci_p_fused_1 < 12544) T.reads(data_pack_local[v0, v1, v2, v3]) T.writes(data_pack[v0, v1, v2, v3]) data_pack[v0, v1, v2, v3] = data_pack_local[v0, v1, v2, v3] - for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(14, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(224, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(2, thread="threadIdx.x"): - for i4_0 in T.serial(2): - for ax0_ax1_ax2_ax3_fused in T.serial(32256): + for eps_0_nu_0_co_0_p_0_fused in T.thread_binding(14, thread="blockIdx.x"): + for eps_1_nu_1_co_1_p_1_fused in T.thread_binding(224, thread="vthread.x"): + for eps_2_nu_2_co_2_p_2_fused in T.thread_binding(2, thread="threadIdx.x"): + for ci_0 in range(2): + for ax0_ax1_ax2_ax3_fused in range(32256): with T.block("data_pack_shared"): v0 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused // 5376) v1 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused % 5376 // 896) - v2 = T.axis.spatial(64, i4_0 * 32 + ax0_ax1_ax2_ax3_fused % 896 // 28) - v3 = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 7 * 28 + ax0_ax1_ax2_ax3_fused % 28) + v2 = T.axis.spatial(64, ci_0 * 32 + ax0_ax1_ax2_ax3_fused % 896 // 28) + v3 = T.axis.spatial(196, eps_0_nu_0_co_0_p_0_fused % 7 * 28 + ax0_ax1_ax2_ax3_fused % 28) T.reads(data_pack[v0, v1, v2, v3]) T.writes(data_pack_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) data_pack_shared[v0, v1, v2, v3] = data_pack[v0, v1, v2, v3] - for ax0_ax1_ax2_ax3_fused in T.serial(36864): + for ax0_ax1_ax2_ax3_fused in range(36864): with T.block("weight_shared"): v0 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused // 6144) v1 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused % 6144 // 1024) - v2 = T.axis.spatial(64, i4_0 * 32 + ax0_ax1_ax2_ax3_fused % 1024 // 32) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused // 7 * 32 + ax0_ax1_ax2_ax3_fused % 32) + v2 = T.axis.spatial(64, ci_0 * 32 + ax0_ax1_ax2_ax3_fused % 1024 // 32) + v3 = T.axis.spatial(64, eps_0_nu_0_co_0_p_0_fused // 7 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) weight_shared[v0, v1, v2, v3] = 
weight[v0, v1, v2, v3] - for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(16, 2, 3, 1, 4, 2, 3, 1, 1, 1): + for ci_1, eps_3, nu_3, co_3, p_3, ci_2, eps_4, nu_4, co_4, p_4 in T.grid(16, 2, 3, 1, 4, 2, 3, 1, 1, 1): with T.block("bgemm"): - eps = T.axis.spatial(6, i0_3 * 3 + i0_4) - nu = T.axis.spatial(6, i1_4 + i0_1_i1_1_i2_1_i3_1_fused // 112 * 3 + i1_3) - co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused // 7 * 32 + i0_1_i1_1_i2_1_i3_1_fused % 112 // 7 * 2 + i0_2_i1_2_i2_2_i3_2_fused + i2_3 + i2_4) - p = T.axis.spatial(196, i3_4 + i0_0_i1_0_i2_0_i3_0_fused % 7 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 7 * 4 + i3_3) - ci = T.axis.reduce(64, i4_0 * 32 + i4_1 * 2 + i4_2) - T.reads(data_pack_shared[eps, nu, ci, p], weight_shared[eps, nu, ci, co]) - T.writes(bgemm_local[eps, nu, co, p]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + v_eps = T.axis.spatial(6, eps_3 * 3 + eps_4) + v_nu = T.axis.spatial(6, eps_1_nu_1_co_1_p_1_fused // 112 * 3 + nu_3 + nu_4) + v_co = T.axis.spatial(64, eps_0_nu_0_co_0_p_0_fused // 7 * 32 + eps_1_nu_1_co_1_p_1_fused % 112 // 7 * 2 + eps_2_nu_2_co_2_p_2_fused + co_3 + co_4) + v_p = T.axis.spatial(196, eps_0_nu_0_co_0_p_0_fused % 7 * 28 + eps_1_nu_1_co_1_p_1_fused % 7 * 4 + p_3 + p_4) + v_ci = T.axis.reduce(64, ci_0 * 32 + ci_1 * 2 + ci_2) + T.reads(data_pack_shared[v_eps, v_nu, v_ci, v_p], weight_shared[v_eps, v_nu, v_ci, v_co]) + T.writes(bgemm_local[v_eps, v_nu, v_co, v_p]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): - bgemm_local[eps, nu, co, p] = T.float32(0) - bgemm_local[eps, nu, co, p] = bgemm_local[eps, nu, co, p] + data_pack_shared[eps, nu, ci, p] * weight_shared[eps, nu, ci, co] + bgemm_local[v_eps, v_nu, v_co, v_p] = T.float32(0) + bgemm_local[v_eps, v_nu, v_co, v_p] = bgemm_local[v_eps, v_nu, v_co, v_p] + data_pack_shared[v_eps, v_nu, v_ci, v_p] * weight_shared[v_eps, v_nu, v_ci, v_co] for ax0, ax1, ax2, ax3 in T.grid(6, 3, 1, 4): with T.block("bgemm_local"): v0 = T.axis.spatial(6, ax0) - v1 = T.axis.spatial(6, i0_1_i1_1_i2_1_i3_1_fused // 112 * 3 + ax1) - v2 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused // 7 * 32 + i0_1_i1_1_i2_1_i3_1_fused % 112 // 7 * 2 + i0_2_i1_2_i2_2_i3_2_fused + ax2) - v3 = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 7 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 7 * 4 + ax3) + v1 = T.axis.spatial(6, eps_1_nu_1_co_1_p_1_fused // 112 * 3 + ax1) + v2 = T.axis.spatial(64, eps_0_nu_0_co_0_p_0_fused // 7 * 32 + eps_1_nu_1_co_1_p_1_fused % 112 // 7 * 2 + eps_2_nu_2_co_2_p_2_fused + ax2) + v3 = T.axis.spatial(196, eps_0_nu_0_co_0_p_0_fused % 7 * 28 + eps_1_nu_1_co_1_p_1_fused % 7 * 4 + ax3) T.reads(bgemm_local[v0, v1, v2, v3]) T.writes(bgemm[v0, v1, v2, v3]) bgemm[v0, v1, v2, v3] = bgemm_local[v0, v1, v2, v3] - for i0_i1_i2_0_i3_0_fused_0 in T.thread_binding(196, thread="blockIdx.x"): - for i0_i1_i2_0_i3_0_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + for n_co_h_0_w_0_fused_0 in T.thread_binding(196, thread="blockIdx.x"): + for n_co_h_0_w_0_fused_1 in T.thread_binding(64, thread="threadIdx.x"): for ax0, ax1 in T.grid(1, 1): for ax2 in T.unroll(4): for ax3 in T.unroll(4): for ax4 in T.unroll(6): for ax5 in T.unroll(6): with T.block("inverse"): - co = T.axis.spatial(64, (i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) // 196 + ax0) - p = T.axis.spatial(196, 
(i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) % 196 // 14 * 14 + (i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) % 14 + ax1) - vh, vw, r_a, r_b = T.axis.remap("SSRR", [ax2, ax3, ax4, ax5]) - T.reads(bgemm[r_a, r_b, co, p]) - T.writes(inverse_local[co, p, vh, vw]) - T.block_attr({"schedule_rule":"conv2d_nchw_winograd_inverse"}) + v_co = T.axis.spatial(64, (n_co_h_0_w_0_fused_0 * 64 + n_co_h_0_w_0_fused_1) // 196 + ax0) + v_p = T.axis.spatial(196, (n_co_h_0_w_0_fused_0 * 64 + n_co_h_0_w_0_fused_1) % 196 + ax1) + v_vh, v_vw, v_r_a, v_r_b = T.axis.remap("SSRR", [ax2, ax3, ax4, ax5]) + T.reads(bgemm[v_r_a, v_r_b, v_co, v_p]) + T.writes(inverse_local[v_co, v_p, v_vh, v_vw]) + T.block_attr({"schedule_rule": "conv2d_nchw_winograd_inverse"}) with T.init(): - inverse_local[co, p, vh, vw] = T.float32(0) - inverse_local[co, p, vh, vw] = inverse_local[co, p, vh, vw] + bgemm[r_a, r_b, co, p] * T.Select(r_a % 6 == 5 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 5 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 0, T.float32(0), T.Select(r_a % 6 == 4 and vh % 4 == 3, T.float32(-8), T.Select(r_a % 6 == 4 and vh % 4 == 2, T.float32(4), T.Select(r_a % 6 == 4 and vh % 4 == 1, T.float32(-2), T.Select(r_a % 6 == 4 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 3 and vh % 4 == 3, T.float32(0.125), T.Select(r_a % 6 == 3 and vh % 4 == 2, T.float32(0.25), T.Select(r_a % 6 == 3 and vh % 4 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 1, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 3, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 1, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 0 and vh % 4 == 3, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 5 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 0, T.float32(0), T.Select(r_b % 6 == 4 and vw % 4 == 3, T.float32(-8), T.Select(r_b % 6 == 4 and vw % 4 == 2, T.float32(4), T.Select(r_b % 6 == 4 and vw % 4 == 1, T.float32(-2), T.Select(r_b % 6 == 4 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 3 and vw % 4 == 3, T.float32(0.125), T.Select(r_b % 6 == 3 and vw % 4 == 2, T.float32(0.25), T.Select(r_b % 6 == 3 and vw % 4 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 1, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 3, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 1, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 0 and vw % 4 == 3, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 0, T.float32(1), 
T.float32(0))))))))))))))))))))))))) - for i2_1, i3_1 in T.grid(4, 4): + inverse_local[v_co, v_p, v_vh, v_vw] = T.float32(0) + inverse_local[v_co, v_p, v_vh, v_vw] = inverse_local[v_co, v_p, v_vh, v_vw] + bgemm[v_r_a, v_r_b, v_co, v_p] * T.Select(v_r_a % 6 == 5 and v_vh % 4 == 3, T.float32(1), T.Select(v_r_a % 6 == 5 and v_vh % 4 == 2, T.float32(0), T.Select(v_r_a % 6 == 5 and v_vh % 4 == 1, T.float32(0), T.Select(v_r_a % 6 == 5 and v_vh % 4 == 0, T.float32(0), T.Select(v_r_a % 6 == 4 and v_vh % 4 == 3, T.float32(-8), T.Select(v_r_a % 6 == 4 and v_vh % 4 == 2, T.float32(4), T.Select(v_r_a % 6 == 4 and v_vh % 4 == 1, T.float32(-2), T.Select(v_r_a % 6 == 4 and v_vh % 4 == 0, T.float32(1), T.Select(v_r_a % 6 == 3 and v_vh % 4 == 3, T.float32(0.125), T.Select(v_r_a % 6 == 3 and v_vh % 4 == 2, T.float32(0.25), T.Select(v_r_a % 6 == 3 and v_vh % 4 == 1, T.float32(0.5), T.Select(v_r_a % 6 == 3 and v_vh % 4 == 0, T.float32(1), T.Select(v_r_a % 6 == 2 and v_vh % 4 == 3, T.float32(1), T.Select(v_r_a % 6 == 2 and v_vh % 4 == 2, T.float32(1), T.Select(v_r_a % 6 == 2 and v_vh % 4 == 1, T.float32(1), T.Select(v_r_a % 6 == 2 and v_vh % 4 == 0, T.float32(1), T.Select(v_r_a % 6 == 1 and v_vh % 4 == 3, T.float32(-1), T.Select(v_r_a % 6 == 1 and v_vh % 4 == 2, T.float32(1), T.Select(v_r_a % 6 == 1 and v_vh % 4 == 1, T.float32(-1), T.Select(v_r_a % 6 == 1 and v_vh % 4 == 0, T.float32(1), T.Select(v_r_a % 6 == 0 and v_vh % 4 == 3, T.float32(0), T.Select(v_r_a % 6 == 0 and v_vh % 4 == 2, T.float32(0), T.Select(v_r_a % 6 == 0 and v_vh % 4 == 1, T.float32(0), T.Select(v_r_a % 6 == 0 and v_vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(v_r_b % 6 == 5 and v_vw % 4 == 3, T.float32(1), T.Select(v_r_b % 6 == 5 and v_vw % 4 == 2, T.float32(0), T.Select(v_r_b % 6 == 5 and v_vw % 4 == 1, T.float32(0), T.Select(v_r_b % 6 == 5 and v_vw % 4 == 0, T.float32(0), T.Select(v_r_b % 6 == 4 and v_vw % 4 == 3, T.float32(-8), T.Select(v_r_b % 6 == 4 and v_vw % 4 == 2, T.float32(4), T.Select(v_r_b % 6 == 4 and v_vw % 4 == 1, T.float32(-2), T.Select(v_r_b % 6 == 4 and v_vw % 4 == 0, T.float32(1), T.Select(v_r_b % 6 == 3 and v_vw % 4 == 3, T.float32(0.125), T.Select(v_r_b % 6 == 3 and v_vw % 4 == 2, T.float32(0.25), T.Select(v_r_b % 6 == 3 and v_vw % 4 == 1, T.float32(0.5), T.Select(v_r_b % 6 == 3 and v_vw % 4 == 0, T.float32(1), T.Select(v_r_b % 6 == 2 and v_vw % 4 == 3, T.float32(1), T.Select(v_r_b % 6 == 2 and v_vw % 4 == 2, T.float32(1), T.Select(v_r_b % 6 == 2 and v_vw % 4 == 1, T.float32(1), T.Select(v_r_b % 6 == 2 and v_vw % 4 == 0, T.float32(1), T.Select(v_r_b % 6 == 1 and v_vw % 4 == 3, T.float32(-1), T.Select(v_r_b % 6 == 1 and v_vw % 4 == 2, T.float32(1), T.Select(v_r_b % 6 == 1 and v_vw % 4 == 1, T.float32(-1), T.Select(v_r_b % 6 == 1 and v_vw % 4 == 0, T.float32(1), T.Select(v_r_b % 6 == 0 and v_vw % 4 == 3, T.float32(0), T.Select(v_r_b % 6 == 0 and v_vw % 4 == 2, T.float32(0), T.Select(v_r_b % 6 == 0 and v_vw % 4 == 1, T.float32(0), T.Select(v_r_b % 6 == 0 and v_vw % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) + for h_1, w_1 in T.grid(4, 4): with T.block("conv2d_winograd"): - n = T.axis.spatial(1, 0) - co = T.axis.spatial(64, (i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) // 196) - h = T.axis.spatial(56, (i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) % 196 // 14 * 4 + i2_1) - w = T.axis.spatial(56, (i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) % 14 * 4 + i3_1) - T.reads(inverse_local[co, n * 196 + h // 4 * 14 + w // 4, h % 4, w % 4]) - 
T.writes(conv2d_winograd[n, co, h, w]) - conv2d_winograd[n, co, h, w] = inverse_local[co, n * 196 + h // 4 * 14 + w // 4, h % 4, w % 4] + v_n = T.axis.spatial(1, 0) + v_co = T.axis.spatial(64, (n_co_h_0_w_0_fused_0 * 64 + n_co_h_0_w_0_fused_1) // 196) + v_h = T.axis.spatial(56, (n_co_h_0_w_0_fused_0 * 64 + n_co_h_0_w_0_fused_1) % 196 // 14 * 4 + h_1) + v_w = T.axis.spatial(56, (n_co_h_0_w_0_fused_0 * 64 + n_co_h_0_w_0_fused_1) % 14 * 4 + w_1) + T.reads(inverse_local[v_co, v_n * 196 + v_h // 4 * 14 + v_w // 4, v_h % 4, v_w % 4]) + T.writes(conv2d_winograd[v_n, v_co, v_h, v_w]) + conv2d_winograd[v_n, v_co, v_h, v_w] = inverse_local[v_co, v_n * 196 + v_h // 4 * 14 + v_w // 4, v_h % 4, v_w % 4] # fmt: on decision_0 = [ ("SampleCategorical", 4), @@ -441,64 +437,62 @@ def nchw_add_relu(p0: T.Buffer((2, 2048, 50, 75), "float32"), p1: T.Buffer((4, 4 @T.prim_func def nchw_add_relu_scheduled(p0: T.Buffer((2, 2048, 50, 75), "float32"), p1: T.Buffer((4, 4, 2048, 2048), "float32"), p2: T.Buffer((1, 2048, 1, 1), "float32"), T_relu: T.Buffer((2, 2048, 50, 75), "float32")): - # function attr dict - T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) - # body + T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":1024}) - input_tile_local = T.alloc_buffer([2048, 1900, 4, 4], dtype="float32", scope="local") - data_pack = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32") - bgemm = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32") - inverse_local = T.alloc_buffer([2048, 1900, 2, 2], dtype="float32", scope="local") - data_pack_local = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32", scope="local") - bgemm_local = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32", scope="local") - data_pack_shared = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32", scope="shared") - p1_shared = T.alloc_buffer([4, 4, 2048, 2048], dtype="float32", scope="shared") + T.block_attr({"meta_schedule.unroll_explicit": 1024}) + input_tile_local = T.alloc_buffer((2048, 1900, 4, 4), scope="local") + data_pack = T.alloc_buffer((4, 4, 2048, 1900)) + bgemm = T.alloc_buffer((4, 4, 2048, 1900)) + inverse_local = T.alloc_buffer((2048, 1900, 2, 2), scope="local") + data_pack_local = T.alloc_buffer((4, 4, 2048, 1900), scope="local") + bgemm_local = T.alloc_buffer((4, 4, 2048, 1900), scope="local") + data_pack_shared = T.alloc_buffer((4, 4, 2048, 1900), scope="shared") + p1_shared = T.alloc_buffer((4, 4, 2048, 2048), scope="shared") for i2_i3_fused_1 in T.thread_binding(256, thread="blockIdx.x"): for i2_i3_fused_2 in T.thread_binding(1024, thread="threadIdx.x"): - for i2_i3_fused_0 in T.serial(15): + for i2_i3_fused_0 in range(15): for ax0, ax1, ax2, ax3 in T.grid(1, 1, 4, 4): with T.block("input_tile"): - T.where(i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2 < 3891200) ci = T.axis.spatial(2048, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) // 1900 + ax0) p = T.axis.spatial(1900, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) % 1900 + ax1) eps, nu = T.axis.remap("SS", [ax2, ax3]) + T.where(i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2 < 3891200) T.reads(p0[p // 950, ci, p % 950 // 38 * 2 + eps - 1, p % 38 * 2 + nu - 1]) T.writes(input_tile_local[ci, p, eps, nu]) - T.block_attr({"schedule_rule":"None"}) - input_tile_local[ci, p, eps, nu] = T.if_then_else(1 <= p % 950 // 38 * 2 + eps and p % 950 // 38 * 2 + 
eps < 51 and 1 <= p % 38 * 2 + nu and p % 38 * 2 + nu < 76, p0[p // 950, ci, p % 950 // 38 * 2 + eps - 1, p % 38 * 2 + nu - 1], T.float32(0), dtype="float32") + T.block_attr({"schedule_rule": "None"}) + input_tile_local[ci, p, eps, nu] = T.if_then_else(1 <= p % 950 // 38 * 2 + eps and p % 950 // 38 * 2 + eps < 51 and 1 <= p % 38 * 2 + nu and p % 38 * 2 + nu < 76, p0[p // 950, ci, p % 950 // 38 * 2 + eps - 1, p % 38 * 2 + nu - 1], T.float32(0)) for i0 in T.unroll(4): for i1 in T.unroll(4): for i4 in T.unroll(4): for i5 in T.unroll(4): with T.block("data_pack"): - T.where((i2_i3_fused_0 * 256 + i2_i3_fused_1) * 1024 + i2_i3_fused_2 < 3891200) eps, nu = T.axis.remap("SS", [i0, i1]) ci = T.axis.spatial(2048, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) // 1900) p = T.axis.spatial(1900, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) % 1900) r_a, r_b = T.axis.remap("RR", [i4, i5]) + T.where((i2_i3_fused_0 * 256 + i2_i3_fused_1) * 1024 + i2_i3_fused_2 < 3891200) T.reads(input_tile_local[ci, p, r_a, r_b]) T.writes(data_pack_local[eps, nu, ci, p]) - T.block_attr({"schedule_rule":"conv2d_nchw_winograd_data_pack"}) + T.block_attr({"schedule_rule": "conv2d_nchw_winograd_data_pack"}) with T.init(): data_pack_local[eps, nu, ci, p] = T.float32(0) data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + input_tile_local[ci, p, r_a, r_b] * T.Select(r_a % 4 == 3 and eps % 4 == 3, T.float32(1), T.Select(r_a % 4 == 3 and eps % 4 == 2, T.float32(0), T.Select(r_a % 4 == 3 and eps % 4 == 1, T.float32(0), T.Select(r_a % 4 == 3 and eps % 4 == 0, T.float32(0), T.Select(r_a % 4 == 2 and eps % 4 == 3, T.float32(0), T.Select(r_a % 4 == 2 and eps % 4 == 2, T.float32(1), T.Select(r_a % 4 == 2 and eps % 4 == 1, T.float32(1), T.Select(r_a % 4 == 2 and eps % 4 == 0, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 3, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 2, T.float32(1), T.Select(r_a % 4 == 1 and eps % 4 == 1, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 0, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 3, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 2, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 1, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 0, T.float32(1), T.float32(0))))))))))))))))) * T.Select(r_b % 4 == 3 and nu % 4 == 3, T.float32(1), T.Select(r_b % 4 == 3 and nu % 4 == 2, T.float32(0), T.Select(r_b % 4 == 3 and nu % 4 == 1, T.float32(0), T.Select(r_b % 4 == 3 and nu % 4 == 0, T.float32(0), T.Select(r_b % 4 == 2 and nu % 4 == 3, T.float32(0), T.Select(r_b % 4 == 2 and nu % 4 == 2, T.float32(1), T.Select(r_b % 4 == 2 and nu % 4 == 1, T.float32(1), T.Select(r_b % 4 == 2 and nu % 4 == 0, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 3, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 2, T.float32(1), T.Select(r_b % 4 == 1 and nu % 4 == 1, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 0, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 3, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 2, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 1, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 0, T.float32(1), T.float32(0))))))))))))))))) for ax0, ax1, ax2, ax3 in T.grid(4, 4, 1, 1): with T.block("data_pack_local"): - T.where(i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2 < 3891200) v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(2048, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) // 1900 + ax2) v3 = T.axis.spatial(1900, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + 
i2_i3_fused_2) % 1900 + ax3) + T.where(i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2 < 3891200) T.reads(data_pack_local[v0, v1, v2, v3]) T.writes(data_pack[v0, v1, v2, v3]) data_pack[v0, v1, v2, v3] = data_pack_local[v0, v1, v2, v3] for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(24320, thread="blockIdx.x"): for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(2, thread="vthread.x"): for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(64, thread="threadIdx.x"): - for i4_0 in T.serial(256): - for ax0_ax1_ax2_ax3_fused in T.serial(640): + for i4_0 in range(256): + for ax0_ax1_ax2_ax3_fused in range(640): with T.block("data_pack_shared"): v0 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused // 12160 * 2 + ax0_ax1_ax2_ax3_fused // 320) v1 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused % 12160 // 6080 * 2 + ax0_ax1_ax2_ax3_fused % 320 // 160) @@ -506,9 +500,9 @@ def nchw_add_relu_scheduled(p0: T.Buffer((2, 2048, 50, 75), "float32"), p1: T.Bu v3 = T.axis.spatial(1900, i0_0_i1_0_i2_0_i3_0_fused % 95 * 20 + ax0_ax1_ax2_ax3_fused % 20) T.reads(data_pack[v0, v1, v2, v3]) T.writes(data_pack_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":1}) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) data_pack_shared[v0, v1, v2, v3] = data_pack[v0, v1, v2, v3] - for ax0_ax1_ax2_ax3_fused in T.serial(1024): + for ax0_ax1_ax2_ax3_fused in range(1024): with T.block("p1_shared"): v0 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused // 12160 * 2 + ax0_ax1_ax2_ax3_fused // 512) v1 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused % 12160 // 6080 * 2 + ax0_ax1_ax2_ax3_fused % 512 // 256) @@ -516,18 +510,18 @@ def nchw_add_relu_scheduled(p0: T.Buffer((2, 2048, 50, 75), "float32"), p1: T.Bu v3 = T.axis.spatial(2048, i0_0_i1_0_i2_0_i3_0_fused % 6080 // 95 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(p1[v0, v1, v2, v3]) T.writes(p1_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) p1_shared[v0, v1, v2, v3] = p1[v0, v1, v2, v3] for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 2, 1, 1, 8, 1, 1, 2, 5): with T.block("bgemm"): - eps = T.axis.spatial(4, i0_4 + i0_0_i1_0_i2_0_i3_0_fused // 12160 * 2 + i0_2_i1_2_i2_2_i3_2_fused // 32 + i0_3) - nu = T.axis.spatial(4, i1_4 + i0_0_i1_0_i2_0_i3_0_fused % 12160 // 6080 * 2 + i1_3) + eps = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused // 12160 * 2 + i0_2_i1_2_i2_2_i3_2_fused // 32 + i0_3 + i0_4) + nu = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused % 12160 // 6080 * 2 + i1_3 + i1_4) co = T.axis.spatial(2048, i0_0_i1_0_i2_0_i3_0_fused % 6080 // 95 * 32 + i0_1_i1_1_i2_1_i3_1_fused * 16 + i0_2_i1_2_i2_2_i3_2_fused % 32 // 4 * 2 + i2_3 * 2 + i2_4) p = T.axis.spatial(1900, i0_0_i1_0_i2_0_i3_0_fused % 95 * 20 + i0_2_i1_2_i2_2_i3_2_fused % 4 * 5 + i3_3 * 5 + i3_4) ci = T.axis.reduce(2048, i4_0 * 8 + i4_1 * 8 + i4_2) T.reads(data_pack_shared[eps, nu, ci, p], p1_shared[eps, nu, ci, co]) T.writes(bgemm_local[eps, nu, co, p]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): bgemm_local[eps, nu, co, p] = T.float32(0) bgemm_local[eps, nu, co, p] = bgemm_local[eps, nu, co, p] + data_pack_shared[eps, nu, ci, p] * p1_shared[eps, nu, ci, co] @@ -551,19 +545,19 @@ def 
nchw_add_relu_scheduled(p0: T.Buffer((2, 2048, 50, 75), "float32"), p1: T.Bu vh, vw, r_a, r_b = T.axis.remap("SSRR", [ax2, ax3, ax4, ax5]) T.reads(bgemm[r_a, r_b, co, p]) T.writes(inverse_local[co, p, vh, vw]) - T.block_attr({"schedule_rule":"conv2d_nchw_winograd_inverse"}) + T.block_attr({"schedule_rule": "conv2d_nchw_winograd_inverse"}) with T.init(): inverse_local[co, p, vh, vw] = T.float32(0) inverse_local[co, p, vh, vw] = inverse_local[co, p, vh, vw] + bgemm[r_a, r_b, co, p] * T.Select(r_a % 4 == 3 and vh % 2 == 1, T.float32(1), T.Select(r_a % 4 == 3 and vh % 2 == 0, T.float32(0), T.Select(r_a % 4 == 2 and vh % 2 == 1, T.float32(1), T.Select(r_a % 4 == 2 and vh % 2 == 0, T.float32(1), T.Select(r_a % 4 == 1 and vh % 2 == 1, T.float32(-1), T.Select(r_a % 4 == 1 and vh % 2 == 0, T.float32(1), T.Select(r_a % 4 == 0 and vh % 2 == 1, T.float32(0), T.Select(r_a % 4 == 0 and vh % 2 == 0, T.float32(1), T.float32(0))))))))) * T.Select(r_b % 4 == 3 and vw % 2 == 1, T.float32(1), T.Select(r_b % 4 == 3 and vw % 2 == 0, T.float32(0), T.Select(r_b % 4 == 2 and vw % 2 == 1, T.float32(1), T.Select(r_b % 4 == 2 and vw % 2 == 0, T.float32(1), T.Select(r_b % 4 == 1 and vw % 2 == 1, T.float32(-1), T.Select(r_b % 4 == 1 and vw % 2 == 0, T.float32(1), T.Select(r_b % 4 == 0 and vw % 2 == 1, T.float32(0), T.Select(r_b % 4 == 0 and vw % 2 == 0, T.float32(1), T.float32(0))))))))) for i0_i1_i2_i3_fused_1 in T.thread_binding(256, thread="blockIdx.x"): for i0_i1_i2_i3_fused_2 in T.thread_binding(1024, thread="threadIdx.x"): - for i0_i1_i2_i3_fused_0 in T.serial(59): + for i0_i1_i2_i3_fused_0 in range(59): with T.block("T_add"): - T.where((i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024 + i0_i1_i2_i3_fused_2 < 15360000) ax0 = T.axis.spatial(2, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) // 7680000) ax1 = T.axis.spatial(2048, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) % 7680000 // 3750) ax2 = T.axis.spatial(50, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) % 3750 // 75) ax3 = T.axis.spatial(75, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) % 75) + T.where((i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024 + i0_i1_i2_i3_fused_2 < 15360000) T.reads(inverse_local[ax1, ax0 * 950 + ax2 // 2 * 38 + ax3 // 2, ax2 % 2, ax3 % 2], p2[0, ax1, 0, 0]) T.writes(T_relu[ax0, ax1, ax2, ax3]) T_relu[ax0, ax1, ax2, ax3] = T.max(inverse_local[ax1, ax0 * 950 + ax2 // 2 * 38 + ax3 // 2, ax2 % 2, ax3 % 2] + p2[0, ax1, 0, 0], T.float32(0)) diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py index d09f2a226cba..4d22c4ff8854 100644 --- a/tests/python/unittest/test_meta_schedule_trace_apply.py +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -192,7 +192,7 @@ def main(p0: T.Buffer((128, 128), "float32"), p1: T.Buffer((128, 128), "float32" for i0_2_init, i1_2_init, i0_3_init in T.grid(4, 4, 2): for i1_3_fused_init in T.vectorized(32): with T.block("T_matmul_NT_init"): - i = T.axis.spatial(128, i0_0_i1_0_i0_1_i1_1_fused // 4 * 32 + i0_0_i1_0_i0_1_i1_1_fused % 4 * 8 + i0_2_init * 2 + i0_3_init) + i = T.axis.spatial(128, i0_0_i1_0_i0_1_i1_1_fused * 8 + i0_2_init * 2 + i0_3_init) j = T.axis.spatial(128, i1_2_init * 32 + i1_3_fused_init) T.reads() T.writes(T_matmul_NT[i, j]) @@ -201,7 +201,7 @@ def main(p0: T.Buffer((128, 128), "float32"), p1: T.Buffer((128, 128), "float32" for i2_0, i0_2, i1_2, 
i2_1, i0_3 in T.grid(8, 4, 4, 16, 2): for i1_3_fused in T.vectorized(32): with T.block("T_matmul_NT_update"): - i = T.axis.spatial(128, i0_0_i1_0_i0_1_i1_1_fused // 4 * 32 + i0_0_i1_0_i0_1_i1_1_fused % 4 * 8 + i0_2 * 2 + i0_3) + i = T.axis.spatial(128, i0_0_i1_0_i0_1_i1_1_fused * 8 + i0_2 * 2 + i0_3) j = T.axis.spatial(128, i1_2 * 32 + i1_3_fused) k = T.axis.reduce(128, i2_0 * 16 + i2_1) T.reads(T_matmul_NT[i, j], p0[i, k], p1_global[k // 16, j // 32, k % 16, j % 32]) @@ -250,10 +250,10 @@ def main( ) j = T.axis.spatial( 128, - i1_4_init - + i0_0_i1_0_fused % 4 * 32 + i0_0_i1_0_fused % 4 * 32 + i0_2_i1_2_fused % 8 * 4 - + i1_3_init, + + i1_3_init + + i1_4_init, ) T.reads() T.writes(T_matmul_NT_local[i, j]) @@ -339,10 +339,10 @@ def main( ) j = T.axis.spatial( 128, - i1_4 - + i0_0_i1_0_fused % 4 * 32 + i0_0_i1_0_fused % 4 * 32 + i0_2_i1_2_fused % 8 * 4 - + i1_3, + + i1_3 + + i1_4, ) k = T.axis.reduce(128, i2_0 * 4 + i2_1 * 4 + i2_2) T.reads(T_matmul_NT_local[i, j], p0_shared[i, k], p1_shared[j, k]) @@ -671,7 +671,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " for ax2_0_3_init, ax3_0_3_init, ax2_0_4_init, ax3_0_4_init in T.grid(1, 1, 1, 1): with T.block("conv2d_nhwc_o_init"): v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3_init + ax2_0_4_init) - v3_o = T.axis.spatial(16, ax3_0_4_init + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3_init) + v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3_init + ax3_0_4_init) T.reads() T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) @@ -722,10 +722,10 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " T.tvm_load_matrix_sync(C_2.data, 16, 16, 16, C_2.elem_offset // C_s0_2 // 16 * (C_s0_2 // 16) + C_2.elem_offset % C_s0_2 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A_1.data, A_1.elem_offset, A_s0_1 * 16, 1, dtype="handle"), A_s0_1, "col_major", dtype="handle") for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 2, 1, 1): with T.block("conv2d_nhwc_o_update"): - v0 = T.axis.reduce(1, ax0_2 + ax0_0 + ax0_1) - v1 = T.axis.reduce(1, ax1_1 + ax1_2 + ax1_0) + v0 = T.axis.reduce(1, ax0_0 + ax0_1 + ax0_2) + v1 = T.axis.reduce(1, ax1_0 + ax1_1 + ax1_2) v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) - v3_o = T.axis.spatial(16, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3) + v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3 + ax3_0_4) v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) @@ -1147,11 +1147,11 @@ def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, for i2_1, i3_1, i4_0_1 in T.grid(7, 1, 1): for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i4_0_2_init, i0_3_init, 
i1_3_init, i2_3_init, i3_3_init, i4_0_3_init in T.grid(1, 1, 1, 1, 1, 1, 1, 1, 7, 1): with T.block("conv2d_NCHWc_int8_o_init"): - n = T.axis.spatial(1, i0_3_init + i0_2_init) - oc_chunk = T.axis.spatial(128, i1_2_init + i1_3_init + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32) + n = T.axis.spatial(1, i0_2_init + i0_3_init) + oc_chunk = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused + i1_2_init + i1_3_init) oh = T.axis.spatial(7, i2_1 + i2_2_init + i2_3_init) ow = T.axis.spatial(7, i3_1 * 7 + i3_2_init * 7 + i3_3_init) - oc_block_o = T.axis.spatial(1, i4_0_3_init + i4_0_1 + i4_0_2_init) + oc_block_o = T.axis.spatial(1, i4_0_1 + i4_0_2_init + i4_0_3_init) T.reads() T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) for i4_1 in T.vectorized(16): @@ -1162,15 +1162,15 @@ def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0 for i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 1, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 7, 1): with T.block("conv2d_NCHWc_int8_o_update"): - n = T.axis.spatial(1, i0_3 + i0_2) - oc_chunk = T.axis.spatial(128, i1_2 + i1_3 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32) + n = T.axis.spatial(1, i0_2 + i0_3) + oc_chunk = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused + i1_2 + i1_3) oh = T.axis.spatial(7, i2_1 + i2_2 + i2_3) ow = T.axis.spatial(7, i3_1 * 7 + i3_2 * 7 + i3_3) - oc_block_o = T.axis.spatial(1, i4_0_3 + i4_0_1 + i4_0_2) + oc_block_o = T.axis.spatial(1, i4_0_1 + i4_0_2 + i4_0_3) kh = T.axis.reduce(1, i5_0 + i5_1) - kw = T.axis.reduce(1, i6_1 + i6_0) + kw = T.axis.reduce(1, i6_0 + i6_1) ic_outer = T.axis.reduce(32, i7_0 * 8 + i7_1) - ic_f_inner = T.axis.reduce(4, i8_1 + i8_0) + ic_f_inner = T.axis.reduce(4, i8_0 + i8_1) ic_s_inner_o = T.axis.reduce(1, i9_0_0 + i9_0_1) T.reads(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) @@ -1187,7 +1187,7 @@ def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, for ax4_fused in T.vectorized(16): with T.block("T_cast_8"): ax0_1 = T.axis.spatial(1, ax0) - ax1_1 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32 + ax1) + ax1_1 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused + ax1) ax2_1 = T.axis.spatial(7, i2_1 + ax2) ax3_1, ax4 = T.axis.remap("SS", [ax3, ax4_fused]) T.reads(conv2d_NCHWc_int8[ax0_1, ax1_1, ax2_1, ax3_1, ax4], p2[ax0_1, ax1_1, 0, 0, ax4], p3[ax0_1, ax1_1, 0, 0, ax4], p4[0], p5[ax0_1, ax1_1, ax2_1, ax3_1, ax4]) @@ -1440,10 +1440,10 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(48, thread="threadIdx.x"): for i0_3_init, i1_3_init, i2_3_init, i3_3_init, i0_4_init, i1_4_init, i2_4_init, i3_4_init in T.grid(1, 1, 14, 1, 1, 1, 1, 1): with T.block("bgemm_init"): - eps = T.axis.spatial(6, i0_4_init + i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 16 + i0_3_init) - nu = T.axis.spatial(6, i1_4_init + i0_0_i1_0_i2_0_i3_0_fused // 28 + i1_3_init) + eps = T.axis.spatial(6, 
i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 16 + i0_3_init + i0_4_init) + nu = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28 + i1_3_init + i1_4_init) p = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 14 + i2_3_init + i2_4_init) - co = T.axis.spatial(64, i3_4_init + i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + i3_3_init) + co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + i3_3_init + i3_4_init) T.reads() T.writes(bgemm_local[eps, nu, p, co]) T.block_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) @@ -1473,10 +1473,10 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), p1_shared[v0, v1, v2, v3] = p1[v0, v1, v2, v3] for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(2, 1, 1, 14, 1, 16, 1, 1, 1, 1): with T.block("bgemm_update"): - eps = T.axis.spatial(6, i0_4 + i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 16 + i0_3) - nu = T.axis.spatial(6, i1_4 + i0_0_i1_0_i2_0_i3_0_fused // 28 + i1_3) + eps = T.axis.spatial(6, i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 16 + i0_3 + i0_4) + nu = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28 + i1_3 + i1_4) p = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 14 + i2_3 + i2_4) - co = T.axis.spatial(64, i3_4 + i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + i3_3) + co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + i3_3 + i3_4) ci = T.axis.reduce(64, i4_0 * 32 + i4_1 * 16 + i4_2) T.reads(bgemm_local[eps, nu, p, co], data_pack_shared[eps, nu, p, ci], p1_shared[eps, nu, co, ci]) T.writes(bgemm_local[eps, nu, p, co]) @@ -1766,8 +1766,8 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2): with T.block("conv2d_nhwc_o"): - v0 = T.axis.reduce(1, ax0_2 + ax0_0 + ax0_1) - v1 = T.axis.reduce(1, ax1_1 + ax1_2 + ax1_0) + v0 = T.axis.reduce(1, ax0_0 + ax0_1 + ax0_2) + v1 = T.axis.reduce(1, ax1_0 + ax1_1 + ax1_2) v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax3_0_3 * 2 + ax3_0_4) v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2) From 5e2f800772b011d1b322e738756b4e44704d64d9 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 10 Apr 2023 19:08:42 -0400 Subject: [PATCH 3/3] Add extra regression tests on floormod(x, 2) change with itermap simplify. 
--- tests/python/unittest/test_arith_iter_affine_map.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 8fdf6b157076..45ec5f1e27e8 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -1190,6 +1190,12 @@ def test_iter_map_simplify_unit_loop_order(): simplify_trivial_iterators=False, ) + assert_iter_map_simplfy( + {y + 64 - x % 2 * 64: y + 64 - x % 2 * 64}, + var_dom([(x, 6), (y, 64)]), + simplify_trivial_iterators=False, + ) + if __name__ == "__main__": tvm.testing.main()
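For reference, below is a minimal usage sketch, separate from the patch itself, of how the `tvm.arith.iter_map_simplify` entry point added in this series can be exercised on the floormod(x, 2) pattern that the regression test above targets. The variable names are illustrative, and the assumption that the call returns the simplified index expressions (rather than a full match result) is mine, not something the patch states.

```python
# Usage sketch (not part of the patch): drive the new Python-level
# iter_map_simplify API on an index mixing floormod(x, 2) with an offset,
# mirroring the regression test added in PATCH 3/3.
import tvm
from tvm import tir
from tvm.arith import iter_map_simplify

x = tir.Var("x", "int32")
y = tir.Var("y", "int32")

# Iterator domains: x in [0, 6), y in [0, 64).
input_iters = {
    x: tvm.ir.Range.from_min_extent(0, 6),
    y: tvm.ir.Range.from_min_extent(0, 64),
}

# An index combining an iterator split, floormod(x, 2), with a constant offset.
index = y + 64 - tir.floormod(x, 2) * 64

# Assumed to return the (possibly unchanged) simplified index expressions.
res = iter_map_simplify(
    [index],
    input_iters,
    simplify_trivial_iterators=False,
)
print(res)
```

As in the regression test, no rewrite is expected for this particular index; the point of the test is that the simplifier leaves it intact rather than producing an incorrect form.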