From 3040864235d6e7449eda5085a3c1212e389bd7c9 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 7 Dec 2022 13:33:35 -0800 Subject: [PATCH 1/2] [TIR] Add preserve_unit_iters option to blockize/tensorize --- include/tvm/arith/iter_affine_map.h | 5 +- include/tvm/tir/schedule/schedule.h | 11 +- python/tvm/arith/iter_affine_map.py | 15 +- python/tvm/tir/schedule/schedule.py | 17 +- src/arith/iter_affine_map.cc | 31 +- src/tir/schedule/concrete_schedule.cc | 16 +- src/tir/schedule/concrete_schedule.h | 6 +- src/tir/schedule/primitive.h | 6 +- .../schedule/primitive/blockize_tensorize.cc | 51 ++-- src/tir/schedule/schedule.cc | 6 +- src/tir/schedule/traced_schedule.cc | 20 +- src/tir/schedule/traced_schedule.h | 6 +- .../unittest/test_arith_iter_affine_map.py | 29 ++ .../unittest/test_meta_schedule_runner.py | 3 + ..._meta_schedule_schedule_rule_mlt_intrin.py | 30 +- ...test_meta_schedule_schedule_rule_mlt_tc.py | 41 +-- .../test_meta_schedule_trace_apply.py | 278 +++++++++--------- .../unittest/test_tir_schedule_blockize.py | 29 +- 18 files changed, 355 insertions(+), 245 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 6b98d84fdf17..0d8bd574ae6e 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -396,6 +396,8 @@ Map InverseAffineIterMap(const Array& iter_map, * \param predicate The predicate constraints on the input iterators * \param check_level The iter mapping checking level. * \param analyzer Analyzer used to get context information. + * \param simplify_trivial_iterators If true, iterators with extent of + * 1 will be replaced with a constant value. * * \return The result list has length len(bindings) + 1 [0, len(bindings)): The iter map matching result. The inner list is of length 2. @@ -407,7 +409,8 @@ Map InverseAffineIterMap(const Array& iter_map, Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - IterMapLevel check_level, arith::Analyzer* analyzer); + IterMapLevel check_level, arith::Analyzer* analyzer, + bool simplify_trivial_iterators = true); /*! * \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 5dbc1b5af395..c4838f2eb8aa 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -563,21 +563,26 @@ class ScheduleNode : public runtime::Object { /*! * \brief Convert the subtree rooted at a specific loop into a block. * \param loop_rv the root of the subtree + * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return the new block */ - virtual BlockRV Blockize(const LoopRV& loop_rv) = 0; + virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0; /*! * \brief Tensorize the computation enclosed by loop with the tensor intrin. * \param loop_rv The loop to be tensorized * \param intrin Name of the tensor intrinsic + * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings */ - virtual void Tensorize(const LoopRV& loop_rv, const String& intrin) = 0; + virtual void Tensorize(const LoopRV& loop_rv, const String& intrin, + bool preserve_unit_iters = true) = 0; /*! * \brief Tensorize the computation enclosed by loop with the tensor intrin. * \param block_rv The block to be tensorized * \param intrin Name of the tensor intrinsic + * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings */ - virtual void Tensorize(const BlockRV& block_rv, const String& intrin) = 0; + virtual void Tensorize(const BlockRV& block_rv, const String& intrin, + bool preserve_unit_iters = true) = 0; /******** Schedule: Annotation ********/ /*! diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 77d6f418b853..54dbcef32590 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -173,7 +173,12 @@ def normalize_iter_map_to_expr(expr): def subspace_divide( - bindings, input_iters, sub_iters, predicate=True, check_level=IterMapLevel.Surjective + bindings, + input_iters, + sub_iters, + predicate=True, + check_level=IterMapLevel.Surjective, + simplify_trivial_iterators=True, ): """Detect if bindings can be written as [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n] @@ -206,6 +211,10 @@ def subspace_divide( check_level : Union[str, IterMapLevel] Checking level of iteration mapping + simplify_trivial_iterators: bool + If true, iterators with extent of 1 will be replaced with a + constant value. + Returns ------- results : List[List[PrimExpr]] @@ -218,7 +227,9 @@ def subspace_divide( """ if isinstance(check_level, str): check_level = IterMapLevel.from_str(check_level) - return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, check_level) + return _ffi_api.SubspaceDivide( + bindings, input_iters, sub_iters, predicate, check_level, simplify_trivial_iterators + ) def inverse_affine_iter_map(iter_map, outputs): diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 91c42f2a8d1d..5ff9d7131396 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2186,13 +2186,15 @@ def after_set_scope( ########## Schedule: Blockize & Tensorize ########## @type_checked - def blockize(self, loop: LoopRV) -> BlockRV: + def blockize(self, loop: LoopRV, preserve_unit_iters: bool = True) -> BlockRV: """Convert the subtree rooted at a specific loop into a block. Parameters ---------- loop : LoopRV The root of the subtree. + preserve_unit_iters : bool + Whether or not to preserve unit iterators in block bindings Returns ------- @@ -2257,10 +2259,15 @@ def after_blockize( block are divisible by the subspace represented by the loops starting at the given loop. """ - return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleBlockize(self, loop, preserve_unit_iters) # type: ignore # pylint: disable=no-member @type_checked - def tensorize(self, block_or_loop: Union[BlockRV, LoopRV], tensor_intrin: str) -> None: + def tensorize( + self, + block_or_loop: Union[BlockRV, LoopRV], + tensor_intrin: str, + preserve_unit_iters: bool = True, + ) -> None: """Tensorize the computation enclosed by loop with the tensor intrinsic. Parameters @@ -2269,6 +2276,8 @@ def tensorize(self, block_or_loop: Union[BlockRV, LoopRV], tensor_intrin: str) - The loop to be tensorized. tensor_intrin : str The tensor intrin or the name of the tensor intrin. + preserve_unit_iters : bool + Whether or not to preserve unit iterators in block bindings Examples -------- @@ -2402,7 +2411,7 @@ def after_tensorize( ) """ _ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member - self, block_or_loop, tensor_intrin + self, block_or_loop, tensor_intrin, preserve_unit_iters ) ########## Schedule: Annotation ########## diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index adba61632fb2..03a36e803be8 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1812,18 +1812,26 @@ class SubspaceDivider { // extent of inner PrimExpr inner_extent; + // The kind of the division result. + enum class Kind { + kInner, // Indicates the division result is totally in inner subspace. + kOuter, // Indicates the division result is totally in outer subspace. + kMixed, // Indicates the division result is mixed in both subspace. + } kind; + DivisionResult(IterMapExpr outer, PrimExpr outer_extent, IterMapExpr inner, - PrimExpr inner_extent) + PrimExpr inner_extent, Kind kind = Kind::kMixed) : outer(std::move(outer)), inner(std::move(inner)), outer_extent(std::move(outer_extent)), - inner_extent(std::move(inner_extent)) {} + inner_extent(std::move(inner_extent)), + kind(kind) {} // whether the division result is totally in outer subspace - bool IsOuter() const { return is_one(inner_extent); } + bool IsOuter() const { return kind == Kind::kOuter; } // whether the division result is totally in inner subspace - bool IsInner() const { return is_one(outer_extent); } + bool IsInner() const { return kind == Kind::kInner; } IterSplitExpr GetOuterAsSplit() const { return GetAsSplit(outer, outer_extent); } @@ -1832,13 +1840,13 @@ class SubspaceDivider { static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) { auto dtype = iter.dtype(); return DivisionResult(IterSumExpr({}, make_const(dtype, 0)), make_const(dtype, 1), iter, - extent); + extent, Kind::kInner); } static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) { auto dtype = iter.dtype(); return DivisionResult(iter, extent, IterSumExpr({}, make_const(dtype, 0)), - make_const(dtype, 1)); + make_const(dtype, 1), Kind::kOuter); } // Special value to indicate the division is not possible @@ -2066,9 +2074,11 @@ class SubspaceDivider { Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - IterMapLevel check_level, arith::Analyzer* analyzer) { + IterMapLevel check_level, arith::Analyzer* analyzer, + bool simplify_trivial_iterators) { if (!IterRangeSanityCheck(input_iters)) return Array>(); - auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer); + auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer, + simplify_trivial_iterators); const Array& maps = res->indices; if (maps.empty()) return {}; @@ -2096,10 +2106,11 @@ Array> SubspaceDivide(const Array& bindings, TVM_REGISTER_GLOBAL("arith.SubspaceDivide") .set_body_typed([](const Array& bindings, const Map& root_iters, - const Array& sub_iters, const PrimExpr& predicate, int check_level) { + const Array& sub_iters, const PrimExpr& predicate, int check_level, + bool simplify_trivial_iterators) { arith::Analyzer ana; return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level), - &ana); + &ana, simplify_trivial_iterators); }); class InverseAffineIterMapTransformer { diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index a0d29a00f886..7ae0185b425c 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -690,25 +690,29 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { } /******** Schedule: Blockize & Tensorize ********/ -BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) { +BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::Blockize(state_, this->GetSRef(loop_rv)); + result = tir::Blockize(state_, this->GetSRef(loop_rv), preserve_unit_iters); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_); return CreateRV(result); } -void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) { +void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin, + bool preserve_unit_iters) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value()); + tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value(), + preserve_unit_iters); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); } -void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) { +void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin, + bool preserve_unit_iters) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value()); + tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value(), + preserve_unit_iters); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 66fca107715b..2381870760a0 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -137,9 +137,9 @@ class ConcreteScheduleNode : public ScheduleNode { int offset) override; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; /******** Schedule: Blockize & Tensorize ********/ - BlockRV Blockize(const LoopRV& loop_rv) override; - void Tensorize(const BlockRV& block_rv, const String& intrin) override; - void Tensorize(const LoopRV& loop_rv, const String& intrin) override; + BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override; + void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override; + void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) override; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index af1988eaaf36..38931aa27147 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -452,18 +452,20 @@ TVM_DLL void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, in * \brief Convert the subtree rooted at a specific loop into a block. * \param self The state of the schedule * \param loop_sref The root of the subtree + * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return The new block */ -TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref); +TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters); /*! * \brief Tensorize the computation enclosed by loop with the tensor intrinsic. * \param self The state of the schedule * \param block_or_loop_sref The block or loop to be tensorized. * \param intrin The tensor intrinsic. + * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings */ TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, - const TensorIntrin& intrin); + const TensorIntrin& intrin, bool preserve_unit_iters); /******** Schedule: Annotation ********/ /*! diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 80a653c544b0..4b4e98638505 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -76,7 +76,7 @@ class SubspaceNotDivisibleError : public ScheduleError { * 1. The binding covers no inner loop vars. * 2. The binding covers only inner loop vars. * - * The bindings are not required to be quasi-affine. + * The bindings are not required to be quasi-affine. Trivial block iters are always preserved. * * \param iter_vars The input iterators * \param bindings The values of iter_vars @@ -146,12 +146,13 @@ Array> TrivialSubspaceDivision(const Array& iter * \param loop_sref The loop that is the root of the second subspace. * \param loops The loops that represents the second part of the subspace. * \param analyzer The arithmetic analyzer to use. + * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings */ Array> SubspaceDivide(const BlockRealize& realize, const StmtSRef& block_sref, // const StmtSRef& loop_sref, // std::vector* loops, - arith::Analyzer* analyzer) { + arith::Analyzer* analyzer, bool preserve_unit_iters) { Array inner_vars; Array outer_vars; Map loop_var_domain; @@ -173,7 +174,8 @@ Array> SubspaceDivide(const BlockRealize& realize, } Array> result = arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars, realize->predicate, - arith::IterMapLevel::Surjective, analyzer); + arith::IterMapLevel::Surjective, analyzer, + /*simplify_trivial_iterators=*/!preserve_unit_iters); if (!result.empty()) { return result; } @@ -191,6 +193,7 @@ Array> SubspaceDivide(const BlockRealize& realize, * \param outer_bindings The outer block bindings. * \param inner_iter_vars The inner block iterators. * \param inner_bindings The inner block bindings. + * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return A substitution plan to the iterators in the original inner block. */ Map DeriveBlockBinding(const Array& iter_vars, // @@ -198,7 +201,7 @@ Map DeriveBlockBinding(const Array& iter_vars, Array* outer_iter_vars, // Array* outer_bindings, // Array* inner_iter_vars, // - Array* inner_bindings) { + Array* inner_bindings, bool preserve_unit_iters) { using arith::IterMapExpr; using arith::IterMapExprNode; using arith::NormalizeIterMapToExpr; @@ -427,7 +430,8 @@ Stmt MakeLoopNest(Stmt stmt, const std::vector& loops) { } BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, - Map* block_sref_reuse, arith::Analyzer* analyzer) { + bool preserve_unit_iters, Map* block_sref_reuse, + arith::Analyzer* analyzer) { TVM_SREF_TO_FOR(loop_sref); // Step 1: Check and get the only block under `loop`. BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); @@ -436,7 +440,7 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, // Step 2: Derive subspace division std::vector loops; Array> division = - SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer); + SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer, preserve_unit_iters); if (division.empty()) { throw SubspaceNotDivisibleError(self->mod, GetRef(loops.back()), block); } @@ -450,7 +454,8 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, Map block_var_subst = // DeriveBlockBinding(block->iter_vars, division, // &outer_iter_vars, &outer_bindings, // - &inner_iter_vars, &inner_bindings); + &inner_iter_vars, &inner_bindings, // + preserve_unit_iters); // Step 4: Do var substitution to adjust to the new block bindings Map inner_iter_dom; for (const IterVar& iter : inner_iter_vars) { @@ -494,10 +499,11 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, : Optional(NullOpt))); } -StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { +StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters) { arith::Analyzer analyzer; Map block_sref_reuse; - BlockRealize blockized = BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer); + BlockRealize blockized = + BlockizeImpl(self, loop_sref, preserve_unit_iters, &block_sref_reuse, &analyzer); self->Replace(loop_sref, blockized, block_sref_reuse); StmtSRef result = self->stmt2ref.at(blockized->block.get()); StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false); @@ -507,7 +513,8 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { return result; } -void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin) { +void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin, + bool preserve_unit_iters) { // Step 1: Blockize the subtree rooted at the given loop if needed BlockRealize block_realize{nullptr}; Optional old_block = NullOpt; @@ -517,7 +524,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } else if (sref->stmt->IsInstance()) { arith::Analyzer analyzer; Map block_sref_reuse; - block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer); + block_realize = BlockizeImpl(self, sref, preserve_unit_iters, &block_sref_reuse, &analyzer); } else { LOG(FATAL) << "TypeError: Tensorize only support For or Block, but gets: " << GetRef(sref->stmt); @@ -617,16 +624,17 @@ struct BlockizeTraits : public UnpackedInstTraits { private: static constexpr size_t kNumInputs = 1; - static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { - return sch->Blockize(loop_rv); + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Bool preserve_unit_iters) { + return sch->Blockize(loop_rv, preserve_unit_iters.operator bool()); } - static String UnpackedAsPython(Array outputs, String loop_rv) { + static String UnpackedAsPython(Array outputs, String loop_rv, Bool preserve_unit_iters) { PythonAPICall py("blockize"); py.Input("loop", loop_rv); + py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); py.SingleOutput(outputs); return py.Str(); } @@ -641,24 +649,27 @@ struct TensorizeTraits : public UnpackedInstTraits { private: static constexpr size_t kNumInputs = 1; - static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin) { + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin, + Bool preserve_unit_iters) { if (const auto* block = block_or_loop_rv.as()) { - sch->Tensorize(GetRef(block), intrin); + sch->Tensorize(GetRef(block), intrin, preserve_unit_iters.operator bool()); } else if (const auto* loop = block_or_loop_rv.as()) { - sch->Tensorize(GetRef(loop), intrin); + sch->Tensorize(GetRef(loop), intrin, preserve_unit_iters.operator bool()); } else { LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); } } - static String UnpackedAsPython(Array outputs, String block_or_loop_rv, String intrin) { + static String UnpackedAsPython(Array outputs, String block_or_loop_rv, String intrin, + Bool preserve_unit_iters) { PythonAPICall py("tensorize"); py.Input("block_or_loop", block_or_loop_rv); py.Input("tensor_intrin", intrin); + py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); return py.Str(); } diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 3fe81c9f433b..d008f3639c78 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -211,11 +211,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") .set_body_method(&ScheduleNode::Blockize); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") - .set_body_typed([](Schedule self, ObjectRef rv, String intrin) { + .set_body_typed([](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) { if (const auto* block_rv = rv.as()) { - self->Tensorize(GetRef(block_rv), intrin); + self->Tensorize(GetRef(block_rv), intrin, preserve_unit_iters); } else if (const auto* loop_rv = rv.as()) { - self->Tensorize(GetRef(loop_rv), intrin); + self->Tensorize(GetRef(loop_rv), intrin, preserve_unit_iters); } else { LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() << ". Its value is: " << rv; diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 010730f66c60..00941b48575d 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -442,34 +442,36 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, /******** Schedule: Blockize & Tensorize ********/ -BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv) { - BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv); +BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) { + BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Blockize"); trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{loop_rv}, - /*attrs=*/{}, + /*attrs=*/{Bool(preserve_unit_iters)}, /*outputs=*/{new_block})); return new_block; } -void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) { - ConcreteScheduleNode::Tensorize(loop_rv, intrin); +void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin, + bool preserve_unit_iters) { + ConcreteScheduleNode::Tensorize(loop_rv, intrin, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Tensorize"); trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{loop_rv}, - /*attrs=*/{intrin}, + /*attrs=*/{intrin, Bool(preserve_unit_iters)}, /*outputs=*/{})); } -void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) { - ConcreteScheduleNode::Tensorize(block_rv, intrin); +void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin, + bool preserve_unit_iters) { + ConcreteScheduleNode::Tensorize(block_rv, intrin, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Tensorize"); trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{intrin}, + /*attrs=*/{intrin, Bool(preserve_unit_iters)}, /*outputs=*/{})); } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index cea2096d20a6..80257f644f6b 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -96,9 +96,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { int offset) final; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; /******** Schedule: Blockize & Tensorize ********/ - BlockRV Blockize(const LoopRV& loop_rv) final; - void Tensorize(const BlockRV& block_rv, const String& intrin) final; - void Tensorize(const LoopRV& loop_rv, const String& intrin) final; + BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final; + void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final; + void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) final; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 6a2fdbbb3f1c..7ae5c58a9507 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -670,6 +670,35 @@ def test_subspace_division(): assert len(res) == 0 +def test_subspace_divide_trivial_iters(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + + # trivial 1.1 + res = tvm.arith.subspace_divide( + [x * 16 + y], var_dom([(x, 1), (y, 16)]), [y], simplify_trivial_iterators=False + ) + res = convert_division(res) + assert len(res) == 2 + tvm.ir.assert_structural_equal(res[0][0], x) + tvm.ir.assert_structural_equal(res[0][1], y) + + # trivial 1.2 + res = tvm.arith.subspace_divide( + [x, y], + var_dom([(x, 1), (y, 1)]), + [y], + simplify_trivial_iterators=False, + ) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], x) + tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][1], y) + + def test_complex(): n0 = create_iter("n0", 2) n1 = create_iter("n1", 4) diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py index a79498304b2f..e10cd89066d4 100644 --- a/tests/python/unittest/test_meta_schedule_runner.py +++ b/tests/python/unittest/test_meta_schedule_runner.py @@ -690,6 +690,8 @@ def _check_correct_add(args_before: List[np.ndarray], args_after: List[np.ndarra a_before, b_before, c_before = args_before a_after, b_after, c_after = args_after c_before = a_before + b_before + print(a_before) + print(a_after) assert (a_before == a_after).all() assert (b_before == b_after).all() assert (c_before == c_after).all() @@ -786,6 +788,7 @@ def test_run_evaluator( # Run the module (runner_future,) = runner.run([runner_input]) runner_result = runner_future.result() + print(runner_result.error_msg) assert runner_result.error_msg is None for result in runner_result.run_secs: if isinstance(result, FloatImm): diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py index e70f7cb2c618..54f342c3a5d8 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py @@ -74,16 +74,16 @@ def vnni_conv2d_nchwc_0(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1): for i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): with T.block("conv2d_NCHWc_int8_o"): - n = T.axis.spatial(1, 0) + n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1) oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2) - oc_block_o = T.axis.spatial(1, 0) - kh = T.axis.reduce(1, 0) - kw = T.axis.reduce(1, 0) + oc_block_o = T.axis.spatial(1, i4_0_2 + i4_0_3 + i4_0_0 + i4_0_1) + kh = T.axis.reduce(1, i5_1 + i5_0) + kw = T.axis.reduce(1, i6_0 + i6_1) ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1) ic_f_inner = T.axis.reduce(4, i8_0 + i8_1) - ic_s_inner_o = T.axis.reduce(1, 0) + ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0) T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16]) T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"}) @@ -119,16 +119,16 @@ def vnni_conv2d_nchwc_1(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac for i0_0, i1_0, i2_0, i3_0, i4_0_0 in T.grid(1, 8, 28, 56, 1): for i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): with T.block("conv2d_NCHWc_int8_o"): - n = T.axis.spatial(1, 0) + n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1) oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2) - oc_block_o = T.axis.spatial(1, 0) - kh = T.axis.reduce(1, 0) - kw = T.axis.reduce(1, 0) + oc_block_o = T.axis.spatial(1, i4_0_2 + i4_0_3 + i4_0_0 + i4_0_1) + kh = T.axis.reduce(1, i5_1 + i5_0) + kw = T.axis.reduce(1, i6_0 + i6_1) ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1) ic_f_inner = T.axis.reduce(4, i8_0 + i8_1) - ic_s_inner_o = T.axis.reduce(1, 0) + ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0) T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16]) T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"}) @@ -162,16 +162,16 @@ def vnni_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): with T.block("conv2d_NCHWc_int8_o"): - n = T.axis.spatial(1, 0) + n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1) oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2) - oc_block_o = T.axis.spatial(1, 0) - kh = T.axis.reduce(1, 0) - kw = T.axis.reduce(1, 0) + oc_block_o = T.axis.spatial(1, i4_0_2 + i4_0_3 + i4_0_0 + i4_0_1) + kh = T.axis.reduce(1, i5_1 + i5_0) + kw = T.axis.reduce(1, i6_0 + i6_1) ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1) ic_f_inner = T.axis.reduce(4, i8_0 + i8_1) - ic_s_inner_o = T.axis.reduce(1, 0) + ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0) T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"}) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index acc626b904a1..73b2c990f08a 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -117,7 +117,7 @@ def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "f for ax0_0, ax1_0 in T.grid(2, 1): with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused) + v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) @@ -152,7 +152,7 @@ def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "f for ax0_0, ax1_0 in T.grid(2, 1): with T.block("C_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused) + v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) @@ -396,7 +396,8 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, for ax2_0_1 in T.serial(18): for ax0_0, ax1_0 in T.grid(1, 1): with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): - v0_o, v1_o = T.axis.remap("SS", [ax0_0_1_ax1_0_1_fused, ax2_0_1]) + v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0) + v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0) T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) @@ -408,7 +409,8 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(1, 1): with T.block("weight_reindex_shared_wmma.matrix_b_o"): - v0_o, v1_o = T.axis.remap("SS", [ax2_0_1, ax0_0_0_ax1_0_0_fused]) + v0_o = T.axis.spatial(18, ax2_0_1 + ax0_0) + v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) @@ -442,7 +444,8 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") for ax0_0, ax1_0 in T.grid(1, 1): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o, v1_o = T.axis.remap("SS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused]) + v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0) + v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) @@ -560,7 +563,7 @@ def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, for ax0_0, ax1_0 in T.grid(2, 1): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1) + v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax1_0) T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) @@ -572,7 +575,7 @@ def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(1, 2): with T.block("B_reindex_shared_wmma.matrix_b_o"): - v0_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1) + v0_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) @@ -706,7 +709,7 @@ def matmul_relu_global_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 1 for ax2_0_1 in T.serial(2): for ax0_0, ax1_0 in T.grid(1, 2): with T.block("A_reindex_shared_wmma.matrix_a_o"): - v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2) + v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax1_0) T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) @@ -754,7 +757,7 @@ def matmul_relu_global_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 1 C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") for ax0_0, ax1_0 in T.grid(1, 4): with T.block("C_reindex_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2) + v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0) T.reads(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) @@ -875,7 +878,7 @@ def padded_matmul_relu_0(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 1 for ax0_0, ax1_0 in T.grid(2, 1): with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused) + v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) @@ -910,7 +913,7 @@ def padded_matmul_relu_0(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 1 for ax0_0, ax1_0 in T.grid(2, 1): with T.block("C_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused) + v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) @@ -1001,7 +1004,7 @@ def conv2d_1x1_0(inputs: T.Buffer[(1, 16, 16, 64), "float16"], weight: T.Buffer[ for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): for ax0_0_1, ax1_0_1 in T.grid(1, 4): with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): - v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) + v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0_1) v1_o = T.axis.spatial(4, ax1_0_1) T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) @@ -1014,10 +1017,8 @@ def conv2d_1x1_0(inputs: T.Buffer[(1, 16, 16, 64), "float16"], weight: T.Buffer[ PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 1): with T.block("weight_reindex_shared_wmma.matrix_b_o"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(1, 0) - v2_o = T.axis.spatial(4, ax2_0) - v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused) + v0, v1, v2_o = T.axis.remap("SSS", [ax0, ax1, ax2_0]) + v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0) T.reads(weight_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) @@ -1029,8 +1030,8 @@ def conv2d_1x1_0(inputs: T.Buffer[(1, 16, 16, 64), "float16"], weight: T.Buffer[ weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 4, 1, 1): with T.block("conv2d_nhwc_o"): - v0 = T.axis.reduce(1, 0) - v1 = T.axis.reduce(1, 0) + v0 = T.axis.reduce(1, ax0_2 + ax0_0 + ax0_1) + v1 = T.axis.reduce(1, ax1_1 + ax1_2 + ax1_0) v2_o = T.axis.spatial(16, ax2_0_4 + ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3) v3_o = T.axis.spatial(4, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3) v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2) @@ -1053,8 +1054,8 @@ def conv2d_1x1_0(inputs: T.Buffer[(1, 16, 16, 64), "float16"], weight: T.Buffer[ conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i], "float32") for ax0_0, ax1_0 in T.grid(1, 1): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) - v1_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused) + v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0) + v1_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax1_0) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py index c8e6bf6a0c73..df1eb614ab97 100644 --- a/tests/python/unittest/test_meta_schedule_trace_apply.py +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -635,26 +635,26 @@ class Conv2dInt8_tensorcore_scheduled: def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), "int8"], p2: T.Buffer[(1, 1, 1, 256), "int32"], p3: T.Buffer[(1, 1, 1, 256), "int32"], p4: T.Buffer[(1, 1, 1, 256), "int64"], p5: T.Buffer[(1, 1, 1, 256), "int64"], p6: T.Buffer[(1, 1, 1, 256), "int64"], p7: T.Buffer[(), "int32"], p8: T.Buffer[1, "int32"], p9: T.Buffer[(16, 56, 56, 256), "int32"], compute: T.Buffer[(16, 56, 56, 256), "uint8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - a0 = T.var("int32") - a1 = T.var("int32") - b0 = T.var("int32") - b1 = T.var("int32") - c0 = T.var("int32") - c1 = T.var("int32") - d0 = T.var("int32") - d0_1 = T.var("int32") - d0_2 = T.var("int32") - d0_3 = T.var("int32") - d1 = T.var("int32") - d1_1 = T.var("int32") - d1_2 = T.var("int32") - d1_3 = T.var("int32") - s0 = T.var("int32") - s0_1 = T.var("int32") - s0_2 = T.var("int32") - s1 = T.var("int32") - s1_1 = T.var("int32") - s1_2 = T.var("int32") + A_s0 = T.var("int32") + A_s0_1 = T.var("int32") + A_s0_2 = T.var("int32") + A_s0_3 = T.var("int32") + A_s1 = T.var("int32") + A_s1_1 = T.var("int32") + A_s1_2 = T.var("int32") + A_s1_3 = T.var("int32") + B_s0 = T.var("int32") + B_s1 = T.var("int32") + C_s0 = T.var("int32") + C_s0_1 = T.var("int32") + C_s0_2 = T.var("int32") + C_s0_3 = T.var("int32") + C_s0_4 = T.var("int32") + C_s1 = T.var("int32") + C_s1_1 = T.var("int32") + C_s1_2 = T.var("int32") + C_s1_3 = T.var("int32") + C_s1_4 = T.var("int32") # body # with T.block("root") conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", scope="shared") @@ -666,83 +666,81 @@ def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), " for ax2_0_0_ax3_0_0_fused in T.thread_binding(3136, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":512, "pragma_unroll_explicit":1}): for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="vthread.x"): for ax2_0_2_ax3_0_2_fused in T.thread_binding(16, thread="threadIdx.x"): - for ax0_0, ax1_0 in T.grid(1, 1): - for ax2_0_3_init, ax3_0_3_init, ax2_0_4_init, ax3_0_4_init in T.grid(1, 1, 1, 1): - with T.block("conv2d_nhwc_o_init"): - v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3_init + ax2_0_4_init) - v3_o = T.axis.spatial(16, ax3_0_4_init + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3_init) - T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) - C = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[d1, d0], scope="wmma.accumulator", offset_factor=16) - T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // d1 // 16 * (d1 // 16) + C.elem_offset % d1 // 16, T.float32(0), dtype="handle")) - for ax4_0_0 in T.serial(2): - for ax0_ax1_fused_0 in T.serial(16): - for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): - for ax0_ax1_fused_2 in T.vectorized(16): - with T.block("pad_temp_reindex_shared"): - v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 32) - v1 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 32) - T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) - T.writes(pad_temp_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]]}) - pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] - for ax0_ax1_ax2_ax3_fused_0 in T.serial(8): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(16, thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(8): - with T.block("p1_reindex_shared"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(1, 0) - v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 8 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) // 32) - v3 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) % 32) - T.reads(p1[v2, v0, v1, v3]) - T.writes(p1_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]]}) - p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] - for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): - for ax0_0_1, ax1_0_1 in T.grid(1, 2): - with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): - v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2) - v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0_1) - T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - A = T.match_buffer(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[s1, s0], scope="shared", offset_factor=16) - C_1 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[d1_1, d0_1], scope="wmma.matrix_a", offset_factor=16) - T.evaluate(T.tvm_load_matrix_sync(C_1.data, 16, 16, 16, C_1.elem_offset // d1_1 // 16 * (d1_1 // 16) + C_1.elem_offset % d1_1 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A.data, A.elem_offset, s1 * 16, 1, dtype="handle"), s1, "row_major", dtype="handle")) - for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 1, 2): - with T.block("p1_reindex_shared_wmma.matrix_b_o"): + for ax2_0_3_init, ax3_0_3_init, ax2_0_4_init, ax3_0_4_init in T.grid(1, 1, 1, 1): + with T.block("conv2d_nhwc_o_init"): + v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3_init + ax2_0_4_init) + v3_o = T.axis.spatial(16, ax3_0_4_init + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3_init) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) + C = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[C_s0, C_s1], scope="wmma.accumulator", offset_factor=16) + T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.float32(0), dtype="handle") + for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2): + for ax0_ax1_fused_0 in T.serial(16): + for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(16): + with T.block("pad_temp_reindex_shared"): + v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 32) + T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + T.writes(pad_temp_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]]}) + pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] + for ax0_ax1_ax2_ax3_fused_0 in T.serial(8): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(16, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(8): + with T.block("p1_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) - v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2) - v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax3_0) - T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - A_1 = T.match_buffer(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[s1_1, s0_1], scope="shared", offset_factor=16) - C_2 = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[d1_2, d0_2], scope="wmma.matrix_b", offset_factor=16) - T.evaluate(T.tvm_load_matrix_sync(C_2.data, 16, 16, 16, C_2.elem_offset // d1_2 // 16 * (d1_2 // 16) + C_2.elem_offset % d1_2 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A_1.data, A_1.elem_offset, s1_1 * 16, 1, dtype="handle"), s1_1, "col_major", dtype="handle")) - for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 2, 1, 1): - with T.block("conv2d_nhwc_o_update"): - v0 = T.axis.reduce(1, 0) - v1 = T.axis.reduce(1, 0) - v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) - v3_o = T.axis.spatial(16, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3) - v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) - A_2 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[a1, a0], scope="wmma.matrix_a", offset_factor=16) - B = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[b1, b0], scope="wmma.matrix_b", offset_factor=16) - C_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[c1, c0], scope="wmma.accumulator", offset_factor=16) - T.evaluate(T.tvm_mma_sync(C_3.data, C_3.elem_offset // c1 // 16 * (c1 // 16) + C_3.elem_offset % c1 // 16, A_2.data, A_2.elem_offset // a1 // 16 * (a1 // 16) + A_2.elem_offset % a1 // 16, B.data, B.elem_offset // b1 // 16 * (b1 // 16) + B.elem_offset % b1 // 16, C_3.data, C_3.elem_offset // c1 // 16 * (c1 // 16) + C_3.elem_offset % c1 // 16, dtype="handle")) + v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 8 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) // 32) + v3 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) % 32) + T.reads(p1[v2, v0, v1, v3]) + T.writes(p1_reindex_shared[v0, v1, v2, v3]) + T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]]}) + p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] + for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): + for ax0_0_1, ax1_0_1 in T.grid(1, 2): + with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0_1) + v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0_1) + T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + A = T.match_buffer(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[A_s0, A_s1], scope="shared", offset_factor=16) + C_1 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[C_s0_1, C_s1_1], scope="wmma.matrix_a", offset_factor=16) + T.tvm_load_matrix_sync(C_1.data, 16, 16, 16, C_1.elem_offset // C_s0_1 // 16 * (C_s0_1 // 16) + C_1.elem_offset % C_s0_1 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A.data, A.elem_offset, A_s0 * 16, 1, dtype="handle"), A_s0, "row_major", dtype="handle") + for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 1, 2): + with T.block("p1_reindex_shared_wmma.matrix_b_o"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax2_0) + v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax3_0) + T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + A_1 = T.match_buffer(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[A_s0_1, A_s1_1], scope="shared", offset_factor=16) + C_2 = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[C_s0_2, C_s1_2], scope="wmma.matrix_b", offset_factor=16) + T.tvm_load_matrix_sync(C_2.data, 16, 16, 16, C_2.elem_offset // C_s0_2 // 16 * (C_s0_2 // 16) + C_2.elem_offset % C_s0_2 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A_1.data, A_1.elem_offset, A_s0_1 * 16, 1, dtype="handle"), A_s0_1, "col_major", dtype="handle") + for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 2, 1, 1): + with T.block("conv2d_nhwc_o_update"): + v0 = T.axis.reduce(1, ax0_2 + ax0_0 + ax0_1) + v1 = T.axis.reduce(1, ax1_1 + ax1_2 + ax1_0) + v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) + v3_o = T.axis.spatial(16, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3) + v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) + A_2 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[A_s0_2, A_s1_2], scope="wmma.matrix_a", offset_factor=16) + B = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[B_s0, B_s1], scope="wmma.matrix_b", offset_factor=16) + C_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[C_s0_3, C_s1_3], scope="wmma.accumulator", offset_factor=16) + T.tvm_mma_sync(C_3.data, C_3.elem_offset // C_s0_3 // 16 * (C_s0_3 // 16) + C_3.elem_offset % C_s0_3 // 16, A_2.data, A_2.elem_offset // A_s0_2 // 16 * (A_s0_2 // 16) + A_2.elem_offset % A_s0_2 // 16, B.data, B.elem_offset // B_s0 // 16 * (B_s0 // 16) + B.elem_offset % B_s0 // 16, C_3.data, C_3.elem_offset // C_s0_3 // 16 * (C_s0_3 // 16) + C_3.elem_offset % C_s0_3 // 16, dtype="handle") for ax0_0, ax1_0 in T.grid(1, 1): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2) - v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2) + v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0) + v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax1_0) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - A_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[d1_3, d0_3], scope="wmma.accumulator", offset_factor=16) - C_4 = T.match_buffer(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[s1_2, s0_2], scope="shared", offset_factor=16) - T.evaluate(T.tvm_store_matrix_sync(A_3.data, 16, 16, 16, A_3.elem_offset // d1_3 // 16 * (d1_3 // 16) + A_3.elem_offset % d1_3 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int32"), C_4.data, C_4.elem_offset, s1_2 * 16, 2, dtype="handle"), s1_2, "row_major", dtype="handle")) + A_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[A_s0_3, A_s1_3], scope="wmma.accumulator", offset_factor=16) + C_4 = T.match_buffer(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[C_s0_4, C_s1_4], scope="shared", offset_factor=16) + T.tvm_store_matrix_sync(A_3.data, 16, 16, 16, A_3.elem_offset // A_s0_3 // 16 * (A_s0_3 // 16) + A_3.elem_offset % A_s0_3 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int32"), C_4.data, C_4.elem_offset, C_s0_4 * 16, 2, dtype="handle"), C_s0_4, "row_major", dtype="handle") for ax0, ax1_0 in T.grid(128, 2): for ax1_1 in T.thread_binding(16, thread="threadIdx.x"): with T.block("conv2d_nhwc_reindex_shared"): @@ -1145,45 +1143,44 @@ def main(p0: T.Buffer[(1, 32, 7, 7, 16), "uint8"], p1: T.Buffer[(128, 32, 1, 1, conv2d_NCHWc_int8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") for i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused in T.parallel(128, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): for i2_1, i3_1, i4_0_1 in T.grid(7, 1, 1): - for i5_0, i6_0 in T.grid(1, 1): - for i1_2_init, i2_2_init, i3_2_init, i1_3_init, i2_3_init, i3_3_init in T.grid(1, 1, 1, 1, 1, 7): - with T.block("conv2d_NCHWc_int8_o_init"): - n = T.axis.spatial(1, 0) - oc_chunk = T.axis.spatial(128, i1_2_init + i1_3_init + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32) - oh = T.axis.spatial(7, i2_1 + i2_2_init + i2_3_init) - ow = T.axis.spatial(7, i3_1 * 7 + i3_2_init * 7 + i3_3_init) - oc_block_o = T.axis.spatial(1, 0) - T.reads() - T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) - for i4_1 in T.vectorized(16): - with T.block("conv2d_NCHWc_int8_init"): - oc_block_i_init = T.axis.spatial(16, i4_1) - T.reads() - T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init]) - conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0 - for i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 7, 1): - with T.block("conv2d_NCHWc_int8_o_update"): - n = T.axis.spatial(1, 0) - oc_chunk = T.axis.spatial(128, i1_2 + i1_3 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32) - oh = T.axis.spatial(7, i2_1 + i2_2 + i2_3) - ow = T.axis.spatial(7, i3_1 * 7 + i3_2 * 7 + i3_3) - oc_block_o = T.axis.spatial(1, 0) - kh = T.axis.reduce(1, 0) - kw = T.axis.reduce(1, 0) - ic_outer = T.axis.reduce(32, i7_0 * 8 + i7_1) - ic_f_inner = T.axis.reduce(4, i8_1 + i8_0) - ic_s_inner_o = T.axis.reduce(1, 0) - T.reads(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) - T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) - A = T.match_buffer(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], [4], dtype="uint8", offset_factor=1) - B = T.match_buffer(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4], [16, 4], dtype="int8", offset_factor=1) - C = T.match_buffer(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], [16], dtype="int32", offset_factor=1) - A_u8x4: T.uint8x4 = A[0:4] - A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32") - B_i8x64: T.int8x64 = B[0, 0:64] - B_i32x16: T.int32x16 = T.reinterpret(B_i8x64, dtype="int32x16") - C_i32x16: T.int32x16 = C[0:16] - C[0:16] = T.call_llvm_pure_intrin(intrin_id, T.uint32(0), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16") + for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i4_0_2_init, i0_3_init, i1_3_init, i2_3_init, i3_3_init, i4_0_3_init in T.grid(1, 1, 1, 1, 1, 1, 1, 1, 7, 1): + with T.block("conv2d_NCHWc_int8_o_init"): + n = T.axis.spatial(1, i0_3_init + i0_2_init) + oc_chunk = T.axis.spatial(128, i1_2_init + i1_3_init + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32) + oh = T.axis.spatial(7, i2_1 + i2_2_init + i2_3_init) + ow = T.axis.spatial(7, i3_1 * 7 + i3_2_init * 7 + i3_3_init) + oc_block_o = T.axis.spatial(1, i4_0_3_init + i4_0_1 + i4_0_2_init) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) + for i4_1 in T.vectorized(16): + with T.block("conv2d_NCHWc_int8_init"): + oc_block_i_init = T.axis.spatial(16, i4_1) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init]) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0 + for i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 1, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 7, 1): + with T.block("conv2d_NCHWc_int8_o_update"): + n = T.axis.spatial(1, i0_3 + i0_2) + oc_chunk = T.axis.spatial(128, i1_2 + i1_3 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32) + oh = T.axis.spatial(7, i2_1 + i2_2 + i2_3) + ow = T.axis.spatial(7, i3_1 * 7 + i3_2 * 7 + i3_3) + oc_block_o = T.axis.spatial(1, i4_0_3 + i4_0_1 + i4_0_2) + kh = T.axis.reduce(1, i5_0 + i5_1) + kw = T.axis.reduce(1, i6_1 + i6_0) + ic_outer = T.axis.reduce(32, i7_0 * 8 + i7_1) + ic_f_inner = T.axis.reduce(4, i8_1 + i8_0) + ic_s_inner_o = T.axis.reduce(1, i9_0_0 + i9_0_1) + T.reads(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) + A = T.match_buffer(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], [4], dtype="uint8", offset_factor=1) + B = T.match_buffer(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4], [16, 4], dtype="int8", offset_factor=1) + C = T.match_buffer(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], [16], dtype="int32", offset_factor=1) + A_u8x4: T.uint8x4 = A[0:4] + A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32") + B_i8x64: T.int8x64 = B[0, 0:64] + B_i32x16: T.int32x16 = T.reinterpret(B_i8x64, dtype="int32x16") + C_i32x16: T.int32x16 = C[0:16] + C[0:16] = T.call_llvm_pure_intrin(T.uint32(10060), T.uint32(0), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16") for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 7): for ax4_fused in T.vectorized(16): with T.block("T_cast_8"): @@ -1740,8 +1737,8 @@ def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), " for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2): for ax0_0_1, ax1_0_1 in T.grid(1, 1): with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): - v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) - v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) + v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0_1) + v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1 + ax1_0_1) T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a"}) @@ -1753,10 +1750,9 @@ def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), " pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1): with T.block("p1_reindex_shared_wmma.matrix_b_o"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(1, 0) + v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0) - v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) + v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1 + ax3_0) T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans"}) @@ -1768,8 +1764,8 @@ def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), " p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2): with T.block("conv2d_nhwc_o"): - v0 = T.axis.reduce(1, 0) - v1 = T.axis.reduce(1, 0) + v0 = T.axis.reduce(1, ax0_2 + ax0_0 + ax0_1) + v1 = T.axis.reduce(1, ax1_1 + ax1_2 + ax1_0) v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax3_0_3 * 2 + ax3_0_4) v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2) @@ -1789,10 +1785,10 @@ def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), " T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "int32") * T.cast(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i], "int32") + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.Cast("int32", pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("int32", p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) for ax0_0, ax1_0 in T.grid(1, 2): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) + v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0) v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) @@ -2478,7 +2474,7 @@ def apply_trace(sch): l311, l312, ) = sch.get_loops(block=b296) - b313 = sch.decompose_reduction(block=b296, loop=l302) + b313 = sch.decompose_reduction(block=b296, loop=l300) sch.unannotate(block_or_loop=b313, ann_key="meta_schedule.auto_tensorize") sch.annotate( block_or_loop=b313, @@ -2723,7 +2719,7 @@ def apply_trace(sch): l188, l189, ) = sch.get_loops(block=b165) - b190 = sch.decompose_reduction(block=b165, loop=l172) + b190 = sch.decompose_reduction(block=b165, loop=l170) sch.unannotate(block_or_loop=b190, ann_key="meta_schedule.auto_tensorize") sch.annotate(block_or_loop=b190, ann_key="meta_schedule.auto_tensorize", ann_val="") b191 = sch.get_block(name="conv2d_NCHWc_int8_o_init", func_name="main") diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py index 12836cdb9e68..a68170009bb5 100644 --- a/tests/python/unittest/test_tir_schedule_blockize.py +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -20,6 +20,7 @@ from tvm import tir from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip +import pytest # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks @@ -247,7 +248,8 @@ def after_rowsum_blockize( verify_trace_roundtrip(sch=s, mod=rowsum) -def test_blockize_outer_int64_shape(): +@pytest.mark.parametrize("preserve_unit_iters", [True, False]) +def test_blockize_outer_int64_shape(preserve_unit_iters): @T.prim_func def single_elementwise_int64( A: T.Buffer[(T.int64(16), T.int64(128)), "float32"], @@ -275,10 +277,31 @@ def after_single_elementwise_int64_blockize( vi_i, vj_o * T.int64(16) + vj_i ] + T.float32(1) + @T.prim_func + def after_single_elementwise_int64_blockize_preserve_unit_iters( + A: T.Buffer[(T.int64(16), T.int64(128)), "float32"], + B: T.Buffer[(T.int64(16), T.int64(128)), "float32"], + ) -> None: + for i0, j0 in T.grid(T.int64(1), T.int64(8)): + with T.block("B_o"): + vi_o = T.axis.spatial(T.int64(1), i0) + vj_o = T.axis.spatial(T.int64(8), j0) + for i1, j1 in T.grid(T.int64(16), T.int64(16)): + with T.block("B"): + vi_i, vj_i = T.axis.remap("SS", [i1, j1]) + B[vi_i, vj_o * T.int64(16) + vj_i] = A[ + vi_i, vj_o * T.int64(16) + vj_i + ] + T.float32(1) + s = tir.Schedule(single_elementwise_int64, debug_mask="all") _, _, i1, _ = s.get_loops(s.get_block("B")) - s.blockize(i1) - tvm.ir.assert_structural_equal(s.mod["main"], after_single_elementwise_int64_blockize) + s.blockize(i1, preserve_unit_iters=preserve_unit_iters) + expected = ( + after_single_elementwise_int64_blockize_preserve_unit_iters + if preserve_unit_iters + else after_single_elementwise_int64_blockize + ) + tvm.ir.assert_structural_equal(s.mod["main"], expected) verify_trace_roundtrip(sch=s, mod=single_elementwise_int64) From d0ef417f9ec442d053d7556ee2d734417e9ad4b6 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 8 Dec 2022 11:12:20 -0800 Subject: [PATCH 2/2] fix --- src/tir/schedule/primitive/blockize_tensorize.cc | 8 ++++---- tests/python/unittest/test_meta_schedule_runner.py | 3 --- tests/python/unittest/test_meta_schedule_trace_apply.py | 2 +- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 4b4e98638505..6860927c4d36 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -430,8 +430,8 @@ Stmt MakeLoopNest(Stmt stmt, const std::vector& loops) { } BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, - bool preserve_unit_iters, Map* block_sref_reuse, - arith::Analyzer* analyzer) { + Map* block_sref_reuse, arith::Analyzer* analyzer, + bool preserve_unit_iters) { TVM_SREF_TO_FOR(loop_sref); // Step 1: Check and get the only block under `loop`. BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); @@ -503,7 +503,7 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_u arith::Analyzer analyzer; Map block_sref_reuse; BlockRealize blockized = - BlockizeImpl(self, loop_sref, preserve_unit_iters, &block_sref_reuse, &analyzer); + BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer, preserve_unit_iters); self->Replace(loop_sref, blockized, block_sref_reuse); StmtSRef result = self->stmt2ref.at(blockized->block.get()); StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false); @@ -524,7 +524,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } else if (sref->stmt->IsInstance()) { arith::Analyzer analyzer; Map block_sref_reuse; - block_realize = BlockizeImpl(self, sref, preserve_unit_iters, &block_sref_reuse, &analyzer); + block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer, preserve_unit_iters); } else { LOG(FATAL) << "TypeError: Tensorize only support For or Block, but gets: " << GetRef(sref->stmt); diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py index e10cd89066d4..a79498304b2f 100644 --- a/tests/python/unittest/test_meta_schedule_runner.py +++ b/tests/python/unittest/test_meta_schedule_runner.py @@ -690,8 +690,6 @@ def _check_correct_add(args_before: List[np.ndarray], args_after: List[np.ndarra a_before, b_before, c_before = args_before a_after, b_after, c_after = args_after c_before = a_before + b_before - print(a_before) - print(a_after) assert (a_before == a_after).all() assert (b_before == b_after).all() assert (c_before == c_after).all() @@ -788,7 +786,6 @@ def test_run_evaluator( # Run the module (runner_future,) = runner.run([runner_input]) runner_result = runner_future.result() - print(runner_result.error_msg) assert runner_result.error_msg is None for result in runner_result.run_secs: if isinstance(result, FloatImm): diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py index df1eb614ab97..9a62207fa261 100644 --- a/tests/python/unittest/test_meta_schedule_trace_apply.py +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -1180,7 +1180,7 @@ def main(p0: T.Buffer[(1, 32, 7, 7, 16), "uint8"], p1: T.Buffer[(128, 32, 1, 1, B_i8x64: T.int8x64 = B[0, 0:64] B_i32x16: T.int32x16 = T.reinterpret(B_i8x64, dtype="int32x16") C_i32x16: T.int32x16 = C[0:16] - C[0:16] = T.call_llvm_pure_intrin(T.uint32(10060), T.uint32(0), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16") + C[0:16] = T.call_llvm_pure_intrin(T.uint32(intrin_id), T.uint32(0), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16") for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 7): for ax4_fused in T.vectorized(16): with T.block("T_cast_8"):