From e4e0f2e17a6819ff4631db5d484fc8273c20a5dc Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 6 Apr 2021 20:14:18 +0000 Subject: [PATCH] [M1b] Scaffolding ScheduleState data structure Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Cody Yu Co-authored-by: Jared Roesch --- include/tvm/runtime/object.h | 15 + include/tvm/tir/schedule/block_scope.h | 271 ++++++ include/tvm/tir/schedule/state.h | 216 +++++ python/tvm/tir/__init__.py | 3 + python/tvm/tir/schedule/__init__.py | 21 + python/tvm/tir/schedule/_ffi_api_schedule.py | 20 + python/tvm/tir/schedule/block_scope.py | 152 +++ python/tvm/tir/schedule/state.py | 185 ++++ python/tvm/tir/stmt.py | 19 +- src/printer/tvmscript_printer.cc | 1 + src/tir/analysis/var_touch.cc | 2 +- src/tir/ir/stmt.cc | 6 +- src/tir/schedule/analysis.h | 47 + src/tir/schedule/analysis/analysis.cc | 60 ++ src/tir/schedule/analysis/verify.cc | 146 +++ src/tir/schedule/block_scope.cc | 162 ++++ src/tir/schedule/state.cc | 870 ++++++++++++++++++ src/tir/schedule/utils.h | 93 ++ tests/python/unittest/test_tir_block_scope.py | 145 +++ .../unittest/test_tir_schedule_state.py | 352 +++++++ 20 files changed, 2776 insertions(+), 10 deletions(-) create mode 100644 include/tvm/tir/schedule/block_scope.h create mode 100644 include/tvm/tir/schedule/state.h create mode 100644 python/tvm/tir/schedule/__init__.py create mode 100644 python/tvm/tir/schedule/_ffi_api_schedule.py create mode 100644 python/tvm/tir/schedule/block_scope.py create mode 100644 python/tvm/tir/schedule/state.py create mode 100644 src/tir/schedule/analysis.h create mode 100644 src/tir/schedule/analysis/analysis.cc create mode 100644 src/tir/schedule/analysis/verify.cc create mode 100644 src/tir/schedule/block_scope.cc create mode 100644 src/tir/schedule/state.cc create mode 100644 src/tir/schedule/utils.h 
create mode 100644 tests/python/unittest/test_tir_block_scope.py create mode 100644 tests/python/unittest/test_tir_schedule_state.py diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 048fc1d5af54..f13bdee09f87 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -739,6 +739,21 @@ struct ObjectPtrEqual { ObjectName* operator->() const { return static_cast(data_.get()); } \ using ContainerType = ObjectName; +/* + * \brief Define object reference methods that is both not nullable and mutable. + * + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + ObjectName* operator->() const { return static_cast(data_.get()); } \ + ObjectName* get() const { return operator->(); } \ + static constexpr bool _type_is_nullable = false; \ + using ContainerType = ObjectName; + /*! * \brief Define CopyOnWrite function in an ObjectRef. * \param ObjectName The Type of the Node. diff --git a/include/tvm/tir/schedule/block_scope.h b/include/tvm/tir/schedule/block_scope.h new file mode 100644 index 000000000000..49d5e7f2c323 --- /dev/null +++ b/include/tvm/tir/schedule/block_scope.h @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/tir/schedule/block_scope.h + * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope. + * \sa StmtSRefNode + * \sa BlockScopeNode + */ +#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_ +#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_ + +#include + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref". + * + * Glossary + * - Block sref: A StmtSRef that points to a TensorIR block. + * - Loop sref: A StmtSRef that points to a TensorIR for loop. + * - Parent sref: The parent reference of an sref is the block or loop reference to the closest + schedulable statement. We define closest to be the nearest schedulable statement of an ancestor in + the AST. + * schedulable statement of its ancestors on the TensorIR AST. + * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref. + * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by + * the TensorIR AST. + */ +class StmtSRefNode : public Object { + public: + /*! + * \brief The block or `for` stmt the object refers to + * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write + * optimization on statements when possible. The strong reference is held in the ScheduleState. + */ + const StmtNode* stmt; + /*! \brief The parent sref. */ + StmtSRefNode* parent; + /*! 
+ * \brief If the statement the sref points to is an element of a SeqStmt in the AST, + * then `seq_index` is set to its index; otherwise `seq_index` is -1 + */ + int64_t seq_index; + + void VisitAttrs(AttrVisitor* v) { + // `stmt` is not visited + // `parent` is not visited + v->Visit("seq_index", &seq_index); + } + + static constexpr const char* _type_key = "tir.StmtSRef"; + TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object); + + /*! \brief Reset the object inplace to the invalid state */ + void Reset() { + this->stmt = nullptr; + this->parent = nullptr; + this->seq_index = -1; + } + + /*! + * \brief Get the referenced statement with proper type checking. + * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt` + * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably + * tvm::tir::BlockNode or tvm::tir::ForNode + * \return nullptr if type check fails, otherwise the casted result for `this->stmt` + */ + template + const StmtType* StmtAs() const { + if (stmt != nullptr && stmt->IsInstance()) { + return static_cast(stmt); + } else { + return nullptr; + } + } +}; + +/*! + * \brief Managed reference to StmtSRefNode + * \sa StmtSRefNode + */ +class StmtSRef : public ObjectRef { + public: + /*! + * \brief The constructor + * \param stmt The corresponding stmt node, can be either block or for loop. + * \param parent The parent sref. + * \param seq_index The location in an array if the parent of the stmt contains multiple children. + * -1 if the parent does not contain multiple children. + */ + TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index); + + /*! \return The mutable pointer to the StmtSRefNode */ + StmtSRefNode* get() const { return static_cast(data_.get()); } + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode); + + public: + /*! 
+ * \return A special StmtSRef, which doesn't point to any stmt in the AST, + * only serving as a "mark" to hint compute-at to do the work of compute-inline + * \note This is only as a faked loop sref for compute-at and reverse-compute-at, + * i.e. + * + * compute-at(block, loop_sref): + * compute-inline(block) if loop_sref.same_as(InlineMark()) + * no-op if loop_sref.same_as(RootMark()) + * compute-at-impl(block, loop_sref) otherwise + */ + TVM_DLL static StmtSRef InlineMark(); + /*! + * \return A special StmtSRef, which doesn't point to any stmt in the AST, + * only serving as a "mark" to hint compute-at to do nothing + * \note This is only as a faked loop sref for compute-at and reverse-compute-at, + * i.e. + * + * compute-at(block, loop_sref): + * compute-inline(block) if loop_sref.same_as(InlineMark()) + * no-op if loop_sref.same_as(RootMark()) + * compute-at-impl(block, loop_sref) otherwise + */ + TVM_DLL static StmtSRef RootMark(); +}; + +/*! + * \brief Type of dependency. Right now we have 4 types of dependencies + * 1) Read-after-write (kRAW) + * 2) Write-after-write (kWAW) + * 3) Write-after-read (kWAR) + * 4) Opaque dependency (kOpaque) + */ +enum class DepKind : int32_t { + kRAW = 0, + kWAW = 1, + kWAR = 2, + kOpaque = 3, +}; + +/*! + * \brief A tuple (src, dst, kind) representing certain types of dependency. + * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is + * read-after-write, which means block B reads the result written by block A. + */ +class DependencyNode : public Object { + public: + /*! \brief The source of the dependency relation */ + StmtSRef src; + /*! \brief The destination of the dependency relation */ + StmtSRef dst; + /*! 
\brief The dependency kind */ + DepKind kind; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("src", &src); + v->Visit("dst", &dst); + v->Visit("kind", &kind); + } + + static constexpr const char* _type_key = "tir.Dependency"; + TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object); +}; + +/*! + * \brief Managed reference to DependencyNode + * \sa DependencyNode + */ +class Dependency : public ObjectRef { + public: + /*! \brief Constructor */ + TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode); +}; + +/*! + * \brief An object with 1-to-1 correspondence with each block reference in the sref tree. + * This data structure is used to track the producer-consumer dependencies between blocks. + * For example, even leaf nodes have a scope node, even though they have no dependencies. + * + * Glossary: + * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref, + * whose components are: + * - scope root: a block sref + * - internal srefs: loop srefs + * - scope leaves: block srefs + * - Child block: The scope leaf blocks under the scope root or a specific internal sref + */ +class BlockScopeNode : public Object { + public: + /*! + * \brief Lookup table for the `src` of dependencies + * \note We intentionally didn't use tvm::Map as the data structure, because we need the values + * inside to be mutable so that they could be further maintained properly during transformations. + */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> src2deps; + /*! \brief Lookup table for the `dst` of dependencies */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dst2deps; + /*! \brief The mapping from the buffer to the blocks that write it */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; + /*! + * \brief This property indicates that the block scope (rooted at its corresponding block) is + * equivalent to a stage pipeline, 
under the following conditions: + * + * 1) The region cover property holds for each of its child blocks + * 2) No write-after-read dependency + */ + bool stage_pipeline{false}; + + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "tir.BlockScope"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object); + + public: + /******** Dependency ********/ + /*! + * \brief Get all dependencies whose `src` equals `src` + * \param src The queried block + * \return The dependencies + */ + TVM_DLL Array GetDepsBySrc(const StmtSRef& src) const; + /*! + * \brief Get all dependencies whose `dst` equals `dst` + * \param dst The queried block + * \return The dependencies + */ + TVM_DLL Array GetDepsByDst(const StmtSRef& dst) const; +}; + +/*! + * \brief Managed reference to BlockScopeNode + * \sa BlockScopeNode + */ +class BlockScope : public ObjectRef { + public: + /*! \brief The constructor creating an empty block scope with no dependency information */ + TVM_DLL BlockScope(); + /*! + * \brief Create the object with the specific leaf blocks, and compute the dependency information + * between the leaf blocks. + * \param child_block_srefs The srefs to the leaf blocks + * \note We assume the leaf blocks are given in pre-DFS order + */ + TVM_DLL BlockScope(const Array& child_block_srefs); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_ diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h new file mode 100644 index 000000000000..12b6fc18dc21 --- /dev/null +++ b/include/tvm/tir/schedule/state.h @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/tir/schedule/state.h + * \brief This file defines ScheduleState, the core data structure of TensorIR scheduling. + */ +#ifndef TVM_TIR_SCHEDULE_STATE_H_ +#define TVM_TIR_SCHEDULE_STATE_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace tir { + +/*! + * \brief The information about a TensorIR block, it contains two categories of information + * 1) Info on the block scope rooted at a specific block, including dependency tracking, + * flags indicating if the scope is a stage pipeline, etc. + * 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it + * reads are completely covered by their producers, etc. + */ +struct BlockInfo { + /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */ + BlockScope scope{nullptr}; + // The properties below are information about the current block realization under its parent scope + /*! \brief Property of a block, indicating the block realization binding is quasi-affine */ + bool affine_binding{false}; + /*! 
+ * \brief Property of a block, indicating each of the block's read regions is fully + * produced by its producers + */ + bool region_cover{false}; + + BlockInfo() = default; + + explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false) + : scope(std::move(scope)), // + affine_binding(affine_binding), // + region_cover(region_cover) {} +}; + +/*! + * \brief The bitmask of the debug flag in the ScheduleStateNode. + * \sa ScheduleStateNode + */ +enum class ScheduleDebugMask : int32_t { + /*! \brief Verify the correctness of the sref tree */ + kVerifySRefTree = 1, + /*! \brief Verify the correctness of affine_binding */ + kVerifyAffineBinding = 2, + /*! \brief Verify the correctness of region_cover */ + kVerifyRegionCover = 4, + /*! \brief Verify the correctness of stage_pipeline */ + kVerifyStagePipeline = 8, +}; + +/*! + * \brief The state of scheduling, which exposes a `Replace` method as + * the primary interface for all the scheduling primitives to manipulate the TensorIR. + * + * The data structure contains the following information + * 1) The AST being scheduled (mod) + * 2) The sref tree of schedulable statements (indicated by the srefs) + * 3) The dependency information of each block scope (block_info) + * 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref) + * 5) A debug flag, if set, extra checking is enabled (debug_mode) + */ +class ScheduleStateNode : public Object { + public: + /*! \brief The AST of the module being scheduled */ + IRModule mod; + /*! + * \brief Mapping from a block sref to its corresponding BlockInfo, + * tracking the dependency inside the block scope, + * and storing necessary information flags for scheduling + */ + std::unordered_map block_info; + /*! \brief The reverse mapping from block/for-loop to their corresponding srefs */ + std::unordered_map stmt2ref; + /*! 
+ * \brief Do extra correctness checking after the class creation + * and each time after calling the Replace method. + * \sa ScheduleDebugMask + */ + int debug_mode; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("mod", &mod); + // `block_info` is not visited + // `stmt2ref` is not visited + v->Visit("debug_mode", &debug_mode); + } + /*! + * \brief Replace the part of the AST, as being pointed to by `src_sref`, + * with a specific statement `tgt_stmt`, and maintain the sref tree accordingly. + * Replace will try to perform copy on write as much as possible when the ScheduleState holds + * the only copy to the IRModule and IR nodes. + * + * Only 3 types of replacements are allowed: from `src_sref->stmt` to `tgt_stmt`. + * 1) Block -> Block + * 2) Loop -> Loop + * 3) Loop -> BlockRealize + * + * \param src_sref The sref to the statement to be replaced + * \param tgt_stmt The statement to be replaced in + * \param block_sref_reuse Maps an old block (to be replaced in the subtree under + * `src_sref->stmt`) to a new block (replaced to, in the subtree under `tgt_stmt`), and enforces + * reuse of srefs between them (rather than create new srefs) i.e. after being replaced, the sref + * that points to the old block will point to the new one + * \note The reuse of loop srefs are detected automatically according to the reuse of loop vars. + */ + TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, + const Map& block_sref_reuse); + /*! + * \brief Trigger the verification according to the `debug_mode` bitmask. + * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree. 
+ * 2) If the bitmask `kVerifyAffineBinding` is on, verify the correctness of `affine_binding` + * 3) If the bitmask `kVerifyRegionCover` is on, verify the correctness of `region_cover` + * 4) If the bitmask `kVerifyStagePipeline` is on, verify the correctness of `stage_pipeline` + */ + TVM_DLL void DebugVerify() const; + + static constexpr const char* _type_key = "tir.ScheduleState"; + TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleStateNode, Object); + + /******** Property of blocks ********/ + /*! \brief Returns the BlockInfo corresponding to the block sref */ + TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const; + /*! + * \brief Get the BlockScope corresponding to the sref of scope root block + * \param scope_root The block sref to be retrieved + * \return The corresponding BlockScope + */ + BlockScope GetBlockScope(const StmtSRef& scope_root) const { + return GetBlockInfo(scope_root).scope; + } + /*! + * \brief Check a cached flag indicating if the specific block has quasi-affine bindings + * \param block_sref The block sref to be checked + * \return A boolean flag indicating if the block has quasi-affine bindings + */ + bool IsAffineBlockBinding(const StmtSRef& block_sref) const { + return GetBlockInfo(block_sref).affine_binding; + } + /*! + * \brief Check a cached flag indicating if each of the specific consumer block's read regions + * is fully produced by its producers + * \param consumer_block_sref The specific consumer block + * \return A boolean flag indicating if each of the block's read regions is fully produced by its producers + */ + bool IsRegionCoveredConsumer(const StmtSRef& consumer_block_sref) const { + return GetBlockInfo(consumer_block_sref).region_cover; + } + /*! 
+ * \brief Check a cached flag indicating if a block scope is equivalent to a stage pipeline + * \param scope_root The block sref to be retrieved + * \return A boolean flag indicating if the scope rooted at the block is a stage pipeline + */ + bool IsStagePipeline(const StmtSRef& scope_root) const { + return GetBlockScope(scope_root)->stage_pipeline; + } +}; + +/*! + * \brief Managed reference to ScheduleStateNode + * \sa ScheduleStateNode + */ +class ScheduleState : public ObjectRef { + public: + /*! + * \brief Construct a schedule state from an IRModule + * \param mod The IRModule to be scheduled + * \param debug_mode Do extra correctness checking after the class creation + * and each time after calling the Replace method. + */ + TVM_DLL explicit ScheduleState(IRModule mod, int debug_mode = 0); + /*! + * \brief Construct a schedule state from a PrimFunc + * \param func The PrimFunc to be scheduled. A new IRModule will be created with + * this specific PrimFunc as "main" function in the module to be scheduled + * \param debug_mode Do extra correctness checking after the class creation + * and each time after calling the Replace method. + */ + TVM_DLL explicit ScheduleState(PrimFunc func, int debug_mode = 0); + + /*! \return The mutable pointer to the ScheduleStateNode */ + ScheduleStateNode* get() const { return static_cast(data_.get()); } + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleState, ObjectRef, ScheduleStateNode); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_STATE_H_ diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index ad91eab64b52..681fc3172c92 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -48,6 +48,9 @@ from .op import comm_reducer, min, max, sum from .op import q_multiply_shift +from .schedule import StmtSRef, BlockScope, ScheduleState + +from . import schedule from . import ir_builder from . import transform from . 
import analysis diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py new file mode 100644 index 000000000000..21721f70b5bf --- /dev/null +++ b/python/tvm/tir/schedule/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Namespace for the TensorIR schedule API.""" + +from .block_scope import BlockScope, Dependency, DepKind, StmtSRef +from .state import ScheduleDebugMask, ScheduleState diff --git a/python/tvm/tir/schedule/_ffi_api_schedule.py b/python/tvm/tir/schedule/_ffi_api_schedule.py new file mode 100644 index 000000000000..ae8bdfde54bf --- /dev/null +++ b/python/tvm/tir/schedule/_ffi_api_schedule.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir.schedule""" +import tvm._ffi + +tvm._ffi._init_api("tir.schedule", __name__) # pylint: disable=protected-access diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py new file mode 100644 index 000000000000..82814521785d --- /dev/null +++ b/python/tvm/tir/schedule/block_scope.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.""" +from enum import IntEnum +from typing import List, Optional, Union + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir import Block, For + +from . 
import _ffi_api_schedule + + +@register_object("tir.StmtSRef") +class StmtSRef(Object): + """An object that refers to schedulable elements in the TensorIR, aka "sref". + + Glossary + - Block sref: An StmtSref that points to a TensorIR block. + - Loop sref: An StmtSRef that points to a TensorIR for loop. + - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest + schedulable statement of its ancestors on the TensorIR AST. + - Root sref: Sref to the root block. Every sref has exactly one parent sref + except for root sref. + - Sref tree: The parent-children-relationship of srefs that forms a tree, + uniquely determined by the TensorIR AST. + """ + + seq_index: int + + @property + def stmt(self) -> Optional[Union[Block, For]]: + """The block/for stmt the object refers to""" + return _ffi_api_schedule.StmtSRefStmt(self) # pylint: disable=no-member + + @property + def parent(self) -> Optional["StmtSRef"]: + """The parent sref""" + return _ffi_api_schedule.StmtSRefParent(self) # pylint: disable=no-member + + @staticmethod + def inline_mark() -> "StmtSRef": + """A special StmtSRef, which doesn't point to any stmt in the AST, + only serving as a "mark" to hint compute-at to do the work of compute-inline""" + return _ffi_api_schedule.StmtSRefInlineMark() # pylint: disable=no-member + + @staticmethod + def root_mark() -> "StmtSRef": + """A special StmtSRef, which doesn't point to any stmt in the AST, + only serving as a "mark" to hint compute-at to do nothing""" + return _ffi_api_schedule.StmtSRefRootMark() # pylint: disable=no-member + + +class DepKind(IntEnum): + """Type of dependency. + + Attributes + ---------- + RAW : int = 0 + Read-after-write dependency + WAW : int = 1 + Write-after-write dependency + WAR : int = 2 + Write-after-read dependency. Not supported in TensorIR for now. 
+ OPAQUE: int = 3 + Opaque dependency + """ + + RAW = 0 + WAW = 1 + WAR = 2 + OPAQUE = 3 + + +@register_object("tir.Dependency") +class Dependency(Object): + """A tuple (src, dst, kind) representing certain types of dependency. + For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is + read-after-write, which means block B reads the result written by block A. + + Parameters + ---------- + src : StmtSRef + The source of the dependency relation + dst : StmtSRef + The destination of the dependency relation + kind : DepKind + The dependency kind + """ + + src: StmtSRef + dst: StmtSRef + kind: DepKind + + +@register_object("tir.BlockScope") +class BlockScope(Object): + """An object corresponds to each block sref in the sref tree, + which tracks the producer-consumer dependency between blocks. + + Glossary: + - Block scope: A contiguous subtree of the sref tree, rooted at each block sref, + whose components are: + - scope root: a block sref + - internal srefs: loop srefs + - scope leaves: block srefs + - Child block: The scope leaf blocks under the scope root or a specific internal sref + """ + + def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]: + """Get all dependencies whose `src` is the target`block`. + + Parameters + ---------- + block: StmtSRef + The queried block + + Returns + ------- + blocks: List[Dependency] + The dependencies + """ + return _ffi_api_schedule.BlockScopeGetDepsBySrc(self, block) # pylint: disable=no-member + + def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]: + """Get all dependencies whose `dst` is the target `block`. 
+ + Parameters + ---------- + block: StmtSRef + The queried block + + Returns + ------- + blocks: List[Dependency] + The dependencies + """ + return _ffi_api_schedule.BlockScopeGetDepsByDst(self, block) # pylint: disable=no-member diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py new file mode 100644 index 000000000000..180fede228e5 --- /dev/null +++ b/python/tvm/tir/schedule/state.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This file defines ScheduleState, the core data structure of TensorIR scheduling.""" +from enum import IntEnum +from typing import Dict, Optional, Union + +from tvm._ffi import register_object +from tvm.ir import IRModule +from tvm.runtime import Object +from tvm.tir import Block, BlockRealize, For, PrimFunc + +from . import _ffi_api_schedule +from .block_scope import BlockScope, StmtSRef + + +class ScheduleDebugMask(IntEnum): + """The bitmask of the `debug_mode` flag in the ScheduleState class. + + If the `debug_mode` flag has a certain bit on, then the correpsonding + verification pass will be conducted. For example, if `(debug_mode & VERIFY_SREF_TREE) != 0`, + then the correctness of the sref tree will be verified after each schedule instruction. 
+ + Attributes + ---------- + VERIFY_SREF_TREE : int = 1 + Verify the correctness of the sref tree + VERIFY_AFFINE_BINDING : int = 2 + Verify the correctness of affine_binding + VERIFY_REGION_COVER : int = 4 + Verify the correctness of region_cover + VERIFY_STAGE_PIPELINE: int = 8 + Verify the correctness of stage_pipeline + """ + + VERIFY_SREF_TREE = 1 + VERIFY_AFFINE_BINDING = 2 + VERIFY_REGION_COVER = 4 + VERIFY_STAGE_PIPELINE = 8 + + +@register_object("tir.ScheduleState") +class ScheduleState(Object): + """The state of scheduling, which exposes a `Replace` method as + the primary interface for all the scheduling primitives to manipulate the TensorIR. + + The data structure contains the following information + 1) The AST being scheduled (mod) + 2) The sref tree of schedulable statements (indicated by the srefs) + 3) The dependency information of each block scope (block_info) + 4) A reverse mapping from the AST nodes to that in the sref tree (get_sref) + 5) A debug flag, if set, extra checking is enabled (debug_mode) + + Parameters + ---------- + mod : IRModule + The AST of the module being scheduled + debug_mode : int + Do extra correctness checking after the object construction + and each time after calling the Replace method. + """ + + mod: IRModule + debug_mode: int + + def __init__( + self, + func_or_mod: Union[PrimFunc, IRModule], + debug_mode: Union[bool, int] = False, + ): + """Construct a schedule state from an IRModule or a PrimFunc + + Parameters + ---------- + func_or_mod : Union[PrimFunc, IRModule] + The IRModule or PrimFunc to be scheduled + debug_mode : Union[bool, int] + Do extra correctness checking after the class creation and each time + after calling the Replace method. 
+ Possible choices of `debug_mode`: + 1) True - Turn on all the checks + 2) False - Turn off all the checks + 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask + """ + if isinstance(debug_mode, bool): + if debug_mode: + debug_mode = -1 + else: + debug_mode = 0 + if not isinstance(debug_mode, int): + raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}") + self.__init_handle_by_constructor__( + _ffi_api_schedule.ScheduleState, # pylint: disable=no-member + func_or_mod, + debug_mode, + ) + + def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]: + """Return the corresponding sref that points to the stmt + + Parameters + ---------- + stmt : Union[Block, For] + The schedulable statement in the TensorIR to be retrieved for its sref + + Returns + ------- + sref : StmtSRef + The corresponding sref + """ + return _ffi_api_schedule.ScheduleStateGetSRef(self, stmt) # pylint: disable=no-member + + def get_block_scope(self, block_sref: StmtSRef) -> BlockScope: + """Get the BlockScope correpsonding to the block sref + + Parameters + ---------- + block_sref : StmtSRef + The block sref to be retrieved + + Returns + ------- + sref : StmtSRef + The corresponding sref + """ + return _ffi_api_schedule.ScheduleStateGetBlockScope( # pylint: disable=no-member + self, block_sref + ) + + def replace( + self, + src_sref: StmtSRef, + tgt_stmt: Union[Block, For, BlockRealize], + block_sref_reuse: Optional[Dict[Block, Block]] = None, + ) -> None: + """ + Replace the part of the AST, as being pointed to by `src_sref`, + with a specific statement `tgt_stmt`, and maintain the sref tree accordingly. + Replace will try to perform copy on write as much as possible when the ScheduleState holds + the only copy to the IRModule and IR nodes. + + Only 3 types of replacements are allowed: from `src_sref->stmt` to `tgt_stmt`. 
+ 1) Block -> Block + 2) Loop -> Loop + 3) Loop -> BlockRealize + + Parameters + ---------- + src_sref : StmtSRef + The sref to the statement to be replaced in the TensorIR AST + + tgt_stmt : Union[Block, For, BlockRealize] + The statement to be replaced to + + block_sref_reuse : Optional[Dict[Block, Block]] = None + Maps an old block (to be replaced in the subtree under `src_sref->stmt`) + to a new block (replaced to, in the subtree under `tgt_stmt`), and enforces + reuse of srefs between them (rather than create new srefs) i.e. after being replaced, + the sref that points to the old block will point to the new one + + Note + ---------- + The reuse of loop srefs are detected automatically according to the reuse of loop vars. + """ + if block_sref_reuse is None: + block_sref_reuse = {} + _ffi_api_schedule.ScheduleStateReplace( # pylint: disable=no-member + self, + src_sref, + tgt_stmt, + block_sref_reuse, + ) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 47462066c364..46f456cd760a 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -26,12 +26,13 @@ assert isinstance(st, tvm.tir.stmt.Store) assert(st.buffer_var == a) """ -from typing import List, Optional, Mapping from enum import IntEnum +from typing import List, Mapping, Optional, Union + import tvm._ffi +from tvm.ir import PrimExpr, Range, Span +from tvm.runtime import Object, const -from tvm.runtime import Object -from tvm.ir import Span, PrimExpr, Range from . import _ffi_api from .buffer import Buffer from .expr import IterVar @@ -589,7 +590,7 @@ class BlockRealize(Stmt): iter_values : List[PrimExpr] The binding values of the block var. - predicate : PrimExpr + predicate : Union[PrimExpr, bool] The predicate of the block. 
block : Block @@ -607,12 +608,18 @@ class BlockRealize(Stmt): def __init__( self, iter_values: List[PrimExpr], - predicate: PrimExpr, + predicate: Union[PrimExpr, bool], block: Block, span: Optional[Span] = None, ): + if isinstance(predicate, bool): + predicate = const(predicate, "bool") self.__init_handle_by_constructor__( - _ffi_api.BlockRealize, iter_values, predicate, block, span + _ffi_api.BlockRealize, + iter_values, + predicate, + block, + span, ) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 438079502306..7afdcab371da 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -683,6 +683,7 @@ Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { doc << "if " << Print(op->condition) << ":"; doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case)); if (!is_one(op->condition) && op->else_case.defined()) { + doc << Doc::NewLine(); doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->else_case)); } return doc; diff --git a/src/tir/analysis/var_touch.cc b/src/tir/analysis/var_touch.cc index 2a2332955582..40a2cce70ae9 100644 --- a/src/tir/analysis/var_touch.cc +++ b/src/tir/analysis/var_touch.cc @@ -18,7 +18,7 @@ */ /*! 
- * \file simple_analysis.cc + * \file var_touch.cc * \brief Implementation of simple passes */ #include diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 2aeaae3eb592..87ead3e883e1 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -221,7 +221,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << "while(" << op->condition << "){\n"; + p->stream << "while(" << op->condition << ") {\n"; p->indent += 2; p->Print(op->body); p->indent -= 2; @@ -781,7 +781,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) auto* op = static_cast(node.get()); p->PrintIndent(); PrintBlockTitle(op, p); - p->stream << "{\n"; + p->stream << " {\n"; p->indent += 2; // Print block elements (e.g. reads/writes, etc) @@ -820,7 +820,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) auto* block_op = op->block.get(); p->PrintIndent(); PrintBlockTitle(block_op, p); - p->stream << "{\n"; + p->stream << " {\n"; p->indent += 2; // Print binding iter_values diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h new file mode 100644 index 000000000000..32d9f6d4cb51 --- /dev/null +++ b/src/tir/schedule/analysis.h @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_ +#define TVM_TIR_SCHEDULE_ANALYSIS_H_ + +#include + +namespace tvm { +namespace tir { + +/******** Verification ********/ +/*! + * \brief Verify the sref tree state is consistent with the IR + * \param self The schedule state containing the sref to be verified + * \throw An exception will be thrown if the sref tree is not valid + */ +void VerifySRefTree(const ScheduleState& self); + +/******** Block-loop relation ********/ +/*! + * \brief Get the leaf blocks of a scope where a specific block/loop is in + * \param self The schedule state + * \param parent_sref The StmtSRef that points to the parent block/loop + * \return A list of leaf blocks + */ +Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_ANALYSIS_H_ diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc new file mode 100644 index 000000000000..005ff373106f --- /dev/null +++ b/src/tir/schedule/analysis/analysis.cc @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/******** Block-loop relation ********/ + +Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { + struct Collector : public StmtVisitor { + public: + static Array Collect(const ScheduleState& self, const Stmt& stmt) { + Collector collector(self); + collector(stmt); + return std::move(collector.result_); + } + + private: + explicit Collector(const ScheduleState& self) : self_(self) {} + + void VisitStmt_(const BlockNode* block) final { + auto it = self_->stmt2ref.find(block); + ICHECK(it != self_->stmt2ref.end()); + result_.push_back(it->second); + } + + const ScheduleState& self_; + Array result_; + }; + + if (parent_sref->stmt->IsInstance()) { + const auto* loop = static_cast(parent_sref->stmt); + return Collector::Collect(self, loop->body); + } else if (parent_sref->stmt->IsInstance()) { + const auto* block = static_cast(parent_sref->stmt); + return Collector::Collect(self, block->body); + } + ICHECK(false) << "Unreachable"; + throw; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc new file mode 100644 index 000000000000..edb62b54cd1b --- /dev/null +++ b/src/tir/schedule/analysis/verify.cc @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +class SRefTreeVerifier : public StmtVisitor { + public: + static void Verify(const ScheduleStateNode* self) { SRefTreeVerifier(self).Verify(); } + + private: + /*! \brief Constructor */ + explicit SRefTreeVerifier(const ScheduleStateNode* self) : self_(self) {} + + void Verify() { + VisitPrimFuncs(self_->mod, [this](const PrimFuncNode* func) { this->VisitStmt(func->body); }); + ICHECK_EQ(n_sref_visited_, static_cast(self_->stmt2ref.size())); + for (const auto& kv : self_->block_info) { + const StmtSRef& sref = kv.first; + ICHECK(sref->stmt != nullptr) + << "InternalError: An expired sref is found in the block_scope mapping"; + auto it = self_->stmt2ref.find(sref->stmt); + ICHECK(it != self_->stmt2ref.end()) + << "InternalError: The sref points to a statement that does not exist in stmt2ref"; + const StmtSRef& sref2 = it->second; + ICHECK(sref.same_as(sref2)) + << "InternalError: The sref points to a statement whose corresponding sref in stmt2ref " + "is not the same object as itself"; + } + ICHECK_EQ(n_block_sref_visited_, static_cast(self_->block_info.size())); + } + + void VisitStmt_(const BlockNode* block) final { + if (init_block_depth_) { + ICHECK(!self_->stmt2ref.count(block)) << "InternalError: A block inside init block has its " + "corresponding sref, which is not allowed"; + StmtVisitor::VisitStmt_(block); + return; + } + ICHECK(self_->stmt2ref.count(block)) + << "InternalError: A BlockNode should appear in sref map, but it didn't\n" + << GetRef(block); + ++n_sref_visited_; 
+ ++n_block_sref_visited_; + const StmtSRef& sref = self_->stmt2ref.at(block); + ICHECK(self_->block_info.count(sref)) + << "InternalError: Cannot find scope information of the BlockNode:\n" + << GetRef(block); + ICHECK(sref->parent == ancestors_.back()) + << "InternalError: Parent information mismatch for BlockNode:\n" + << GetRef(block) << "\nIts parent is supposed to be:\n" + << GetRef(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n" + << (sref->parent ? Optional(GetRef(sref->parent->stmt)) + : Optional(NullOpt)); + ancestors_.push_back(sref.operator->()); + if (block->init.defined()) { + ++init_block_depth_; + VisitStmt(block->init.value()); + --init_block_depth_; + } + VisitStmt(block->body); + ancestors_.pop_back(); + } + + void VisitStmt_(const ForNode* loop) final { + if (init_block_depth_) { + ICHECK(!self_->stmt2ref.count(loop)) << "InternalError: A loop inside init block has its " + "corresponding sref, which is not allowed"; + StmtVisitor::VisitStmt_(loop); + return; + } + ICHECK(self_->stmt2ref.count(loop)) + << "InternalError: A ForNode should appear in sref map, but it didn't\n" + << GetRef(loop); + ++n_sref_visited_; + const StmtSRef& sref = self_->stmt2ref.at(loop); + Optional stmt = NullOpt; + ICHECK(sref->parent == ancestors_.back()) + << "InternalError: Parent information mismatch for ForNode:\n" + << GetRef(loop) << "\nIts parent is supposed to be:\n" + << GetRef(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n" + << (sref->parent ? 
Optional(GetRef(sref->parent->stmt)) + : Optional(NullOpt)); + ancestors_.push_back(sref.operator->()); + StmtVisitor::VisitStmt_(loop); + ancestors_.pop_back(); + } + + void VisitStmt_(const SeqStmtNode* seq_stmt) final { + // Verify seq_index + if (init_block_depth_) { + StmtVisitor::VisitStmt_(seq_stmt); + return; + } + int n = static_cast(seq_stmt->seq.size()); + for (int i = 0; i < n; ++i) { + const Stmt& child = seq_stmt->seq[i]; + StmtSRef sref{nullptr}; + if (const auto* realize = child.as()) { + const auto* block = realize->block.get(); + ICHECK(self_->stmt2ref.count(block)); + sref = self_->stmt2ref.at(block); + } else if (child->IsInstance()) { + ICHECK(self_->stmt2ref.count(child.get())); + sref = self_->stmt2ref.at(child.get()); + } else { + continue; + } + ICHECK_EQ(sref->seq_index, i) << "InternalError: A StmtSRef has incorrect seq_index"; + } + StmtVisitor::VisitStmt_(seq_stmt); + } + + /*! \brief The schedule it belongs to */ + const ScheduleStateNode* self_; + /*! \brief Parent information during the visit */ + std::vector ancestors_ = {nullptr}; + /*! \brief If the visitor is currently in the init block */ + int init_block_depth_ = 0; + /*! \brief Number of srefs that are visited */ + int n_sref_visited_ = 0; + /*! \brief Number of block srefs that are visited */ + int n_block_sref_visited_ = 0; +}; + +void VerifySRefTree(const ScheduleState& self) { SRefTreeVerifier::Verify(self.get()); } + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/block_scope.cc b/src/tir/schedule/block_scope.cc new file mode 100644 index 000000000000..f1ce65e48e03 --- /dev/null +++ b/src/tir/schedule/block_scope.cc @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +/******** Utility functions ********/ + +template +using SMap = std::unordered_map; + +/*! + * \brief Add a dependency relation. + * \param src The source of the dependency + * \param dst The destination of the dependecy + * \param kind Type of the dependency + * \note This method is effectively NOP on self-loops + */ +void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) { + if (!src.same_as(dst)) { + Dependency dep(src, dst, kind); + self->src2deps[src].push_back(dep); + self->dst2deps[dst].push_back(dep); + } +} + +/******** Constructors ********/ + +StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) { + ObjectPtr n = make_object(); + n->stmt = stmt; + n->parent = parent; + n->seq_index = seq_index; + data_ = std::move(n); +} + +StmtSRef StmtSRef::InlineMark() { + static StmtSRef result(nullptr, nullptr, -1); + return result; +} + +StmtSRef StmtSRef::RootMark() { + static StmtSRef result(nullptr, nullptr, -1); + return result; +} + +Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) { + ObjectPtr node = make_object(); + node->src = std::move(src); + node->dst = std::move(dst); + node->kind = kind; + data_ = std::move(node); +} + +BlockScope::BlockScope() { data_ = make_object(); } + +BlockScope::BlockScope(const Array& 
child_block_srefs) { + ObjectPtr n = make_object(); + SMap> buffer_readers; + SMap>& buffer_writers = n->buffer_writers; + for (const StmtSRef& child_block_sref : child_block_srefs) { + const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block, child_block_sref); + // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer + for (const BufferRegion& region : child_block->reads) { + buffer_readers[region->buffer].push_back(child_block_sref); + } + for (const BufferRegion& region : child_block->writes) { + buffer_writers[region->buffer].push_back(child_block_sref); + } + // Step 2. Update RAW dependency + for (const BufferRegion& region : child_block->reads) { + auto it = buffer_writers.find(region->buffer); + if (it != buffer_writers.end()) { + for (const StmtSRef& from : it->second) { + AddDependency(n.get(), from, child_block_sref, DepKind::kRAW); + } + } + } + // Step 3. Update WAW dependency + for (const BufferRegion& region : child_block->writes) { + auto it = buffer_writers.find(region->buffer); + if (it != buffer_writers.end()) { + for (const StmtSRef& from : it->second) { + AddDependency(n.get(), from, child_block_sref, DepKind::kWAW); + } + } + } + // Step 4. 
Update WAR dependency + for (const BufferRegion& region : child_block->writes) { + auto it = buffer_readers.find(region->buffer); + if (it != buffer_readers.end()) { + for (const StmtSRef& from : it->second) { + AddDependency(n.get(), from, child_block_sref, DepKind::kWAR); + } + } + } + } + data_ = std::move(n); +} + +/******** Dependency ********/ + +Array BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const { + auto iter = this->src2deps.find(block_sref); + if (iter != this->src2deps.end()) { + return iter->second; + } else { + return {}; + } +} + +Array BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const { + auto iter = this->dst2deps.find(block_sref); + if (iter != this->dst2deps.end()) { + return iter->second; + } else { + return {}; + } +} + +/******** FFI ********/ + +TVM_REGISTER_NODE_TYPE(StmtSRefNode); +TVM_REGISTER_NODE_TYPE(DependencyNode); +TVM_REGISTER_NODE_TYPE(BlockScopeNode); + +TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefStmt") + .set_body_typed([](StmtSRef sref) -> Optional { + return GetRef>(sref->stmt); + }); +TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefParent") + .set_body_typed([](StmtSRef sref) -> Optional { + return GetRef>(sref->parent); + }); +TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefRootMark") // + .set_body_typed(StmtSRef::RootMark); +TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefInlineMark") // + .set_body_typed(StmtSRef::InlineMark); +TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsBySrc") + .set_body_method(&BlockScopeNode::GetDepsBySrc); +TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsByDst") + .set_body_method(&BlockScopeNode::GetDepsByDst); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc new file mode 100644 index 000000000000..d1b899b05439 --- /dev/null +++ b/src/tir/schedule/state.cc @@ -0,0 +1,870 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +template +using SMap = std::unordered_map; + +/**************** Utility functions ****************/ + +/*! + * \brief Set the `StmtSRefNode::seq_index` field for stmt + * \param self The schedule class + * \param stmt The statement, or the realize node of the statement whose sref to be set + * \param seq_index The seq_index to be set + * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block + */ +void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) { + if (const auto* realize = stmt.as()) { + const BlockNode* block = realize->block.get(); + ICHECK(self->stmt2ref.count(block)); + self->stmt2ref.at(block)->seq_index = seq_index; + } else if (const auto* block = stmt.as()) { + ICHECK(self->stmt2ref.count(block)); + self->stmt2ref.at(block)->seq_index = seq_index; + } else if (const auto* loop = stmt.as()) { + ICHECK(self->stmt2ref.count(loop)); + self->stmt2ref.at(loop)->seq_index = seq_index; + } else { + // do nothing + } +} + +/*! 
+ * \brief Update seq_index of the children of a SeqStmt + * \param self The schedule class + * \param seq_stmt The SeqStmt whose children need updating + */ +void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) { + int i = 0; + for (const Stmt& stmt : seq_stmt->seq) { + SetSeqIndex(self, stmt, i); + ++i; + } +} + +/*! + * \brief Update the sref information on the schedule class, as well as the statement of sref itself + * More specifically, update + * `sref->stmt` to `new_stmt` + * `self->stmt2ref`, remove the old statement that sref points to, and add the new statement + * \param self The schedule class to be updated + * \param sref The sref to be updated + * \param new_stmt The statement that replaces the statement inside the sref + */ +void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) { + ICHECK(new_stmt->IsInstance() || new_stmt->IsInstance()); + const StmtNode* old_stmt = sref->stmt; + ICHECK_NE(new_stmt, old_stmt); + self->stmt2ref[new_stmt] = GetRef(sref); + self->stmt2ref.erase(sref->stmt); + sref->stmt = new_stmt; +} + +/*! 
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to + * \param mod The IRModule + * \param root_block The root block of the PrimFunc + * \param result_g_var The result GlobalVar + * \return The result PrimFunc where the root block belongs to + * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write + */ +const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, + GlobalVar* result_g_var) { + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* func = base_func.as()) { + if (const auto* realize = func->body.as()) { + if (realize->block.get() == root_block) { + *result_g_var = g_var; + return func; + } + } + } + } + LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the " + "statement:\n" + << GetRef(root_block); + throw; +} + +/**************** Creation ****************/ + +/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */ +class StateCreator : private StmtVisitor { + public: + /*! + * \brief The entry function + * \param self The schedule state to be completed + */ + static ObjectPtr Create(IRModule mod, int debug_mode) { + ObjectPtr n = make_object(); + ScheduleStateNode* self = n.get(); + // Set `n->mod` + n->mod = std::move(mod); + // Set `n->debug_mode` + n->debug_mode = debug_mode; + // Set `n->stmt2ref` and `n->block_info` + StateCreator creator(self); + for (const auto& kv : n->mod->functions) { + const BaseFunc& base_func = kv.second; + if (const auto* func = base_func.as()) { + creator.VisitStmt(func->body); + } + } + return n; + } + + private: + explicit StateCreator(ScheduleStateNode* self) + : self_(self), srefs_{}, realizes_{}, block_frames_{} { + block_frames_.emplace({}); + } + + /*! 
+ * \brief Add a new statement to the stack, which becomes the current scope + * \param stmt A for-loop statement or a block statement + * \return A sref to the stmt + */ + StmtSRef PushSRef(const StmtNode* stmt) { + if (srefs_.empty()) { + srefs_.push_back( + StmtSRef(stmt, + /*parent=*/nullptr, + /*seq_index=*/-1)); // `seq_index` will be set properly in SetSeqIndex + } else { + StmtSRefNode* parent = srefs_.back().get(); + srefs_.push_back( + StmtSRef(stmt, parent, + /*seq_index=*/-1)); // `seq_index` will be set properly in SetSeqIndex + } + return srefs_.back(); + } + + /*! \brief Pop the top of the scope and record it in stmt2ref map */ + StmtSRef PopAndRecordSRef() { + StmtSRef sref = std::move(srefs_.back()); + self_->stmt2ref[sref->stmt] = sref; + srefs_.pop_back(); + return sref; + } + + void MakeBlockInfo(StmtSRef scope_root) { + // Calculate `BlockInfo::scope` + Array child_block_srefs = std::move(block_frames_.back()); + BlockInfo& info = + self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs))) + .first->second; + // TODO(@junrushao1994): calculate the flags + // Set `affine_binding` + info.affine_binding = false; + // Set `region_cover` + info.region_cover = false; + // Set `stage_pipeline` + info.scope->stage_pipeline = false; + } + + void VisitStmt_(const ForNode* loop) final { + PushSRef(loop); + VisitStmt(loop->body); + PopAndRecordSRef(); + } + + void VisitStmt_(const BlockRealizeNode* realize) final { + realizes_.push_back(realize); + block_frames_.emplace_back(); + const BlockNode* block = realize->block.get(); + // Recursive visit + PushSRef(block); + VisitStmt(block->body); // `block->init` is not visited + StmtSRef sref = PopAndRecordSRef(); + // Create BlockInfo for the block + MakeBlockInfo(sref); + // Update parent scope + block_frames_.pop_back(); + block_frames_.back().push_back(sref); + realizes_.pop_back(); + } + + void VisitStmt_(const SeqStmtNode* seq_stmt) final { + // Set `seq_index` information 
for SeqStmtNode + StmtVisitor::VisitStmt_(seq_stmt); + SetSeqIndexInChildren(self_, seq_stmt); + } + + /*! \brief The result ScheduleStateNode */ + ScheduleStateNode* self_; + /*! \brief The stack frame used to indicate the current scope */ + std::vector srefs_; + /*! \brief The BlockRealize in the ancestors */ + std::vector realizes_; + /*! \brief The stack frames of blocks in the DFS visit. */ + std::vector> block_frames_; +}; + +/**************** Constructor ****************/ + +ScheduleState::ScheduleState(IRModule mod, int debug_mode) { + CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported"; + data_ = StateCreator::Create(mod, debug_mode); + (*this)->DebugVerify(); +} + +ScheduleState::ScheduleState(PrimFunc func, int debug_mode) + : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {} + +/**************** Replace ****************/ + +/* + * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST to a new + * subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs reused, + * so that the srefs users hold doesn't expire. For example, if we split a loop into 2, and the + * original loop has a child block, then the sref to the child block should be reused, so that users + * won't have to acquire that sref again. + * + * The workflow of the replacement algorithm is: + * 1) Detect all possible reuses in class ReuseInfo + * 2) Remove the expired srefs in class SRefTreePruner + * 3) Update the reused the sref, and create the srefs for new statements, in class SRefUpdater + * 4) Renew the ancestors of `src_stmt` to reflect the replacement + */ + +/*! + * \brief Record the different sref reuse types in the replacement + * + * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`, + * which, given the immutability of the IR, means the entire subtree is unchanged, + * and we do not need to recurse into the subtree. 
+ * + * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`), + * which are both loops or both blocks, + * there is correspondence between them, + * which makes us to reuse the sref pointing to `src`, and change it to point to `tgt`. + * + * \note The intact reuse and loop sref reuse are collected in the ReuseCollector, + * while the block reuse is specified by the caller. + * + * \sa ReuseCollector + */ +struct ReuseInfo { + /*! + * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding + * sref is reused and it is intact reuse. + */ + std::unordered_set intact; + /*! + * \brief Kind 2.1. Loop sref reuse + * If the loop var of a loop is in `loop_sref_possible_reuse`, + * it means that when `src_stmt` has a loop that uses this loop var, + * the reuse kind is loop sref reuse. + * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't + * contain a loop that uses this loop var, and that is the reason why it is named "possible". + */ + std::unordered_set loop_sref_possible_reuse; + /*! + * \brief Kind 2.2. Block sref reuse. + * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`, + * indicating the sref to the old block should be reused in the sref to the new block. + */ + std::unordered_map block_sref_reuse; +}; + +/*! + * \brief A helper visitor which collects two cases of sref reuses in the `tgt_stmt`: + * + * 1) Intact: the subtree represented by `intact` appears on both old and new IR. + * Given the immutability of the IR, we can quickly decide that the entire subtree is unchanged, + * which means we do not need to visit into the subtree of the old statement. 
+ * + * 2) Reused block/loop: for two different objects (`src`, `tgt`), + * which are both loops or both blocks, + * and there is correspondence between them, + * which makes us to reuse the sref pointing to `src`, and changes it to point to `tgt`, + */ +class ReuseCollector : public StmtVisitor { + public: + static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) { + ReuseCollector collector(self); + collector.VisitStmt(tgt_stmt); + ReuseInfo result; + result.intact = {collector.intact_.begin(), collector.intact_.end()}; + result.loop_sref_possible_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()}; + // `result.block_reuse ` is not set here because ReuseCollector doesn't collect it, + // and it is supposed to be properly set by the caller. + return result; + } + + private: + explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {} + + void VisitStmt_(const ForNode* op) final { + if (self_->stmt2ref.count(op)) { + intact_.push_back(op); + } else { + // Collect loop vars for detecting reuse of loop sref + loop_vars_.push_back(op->loop_var.get()); + StmtVisitor::VisitStmt_(op); + } + } + + void VisitStmt_(const BlockNode* op) final { + if (self_->stmt2ref.count(op)) { + intact_.push_back(op); + } else { + StmtVisitor::VisitStmt_(op); + } + } + + /*! \brief The schedule state to be worked on */ + const ScheduleStateNode* self_; + /*! \brief The intact statements we have collected along the way of visiting */ + std::vector intact_; + /*! \brief The loop variable we collected in the tgt_stmt */ + std::vector loop_vars_; +}; + +/*! + * \brief A helper visitor which removes the stale srefs in the `src_stmt` + * that are useless after the replacement. + * + * It uses the reuse information previously collected to + * 1) delete those srefs that are not reused. + * 2) return the sref objects that are loop/block sref reuses, but not intact reuses + */ +class SRefTreePruner : public StmtVisitor { + public: + /*! 
+ * \brief The entry function + * \param self The schedule class + * \param info The reuse info about intact reuse and loop/block reuse + * \param src_stmt The `src_stmt` where stale srefs to be removed + * \return Mapping from the reuse elements to reused srefs, more specifically: + * 1) Loop reuse: maps a loop var to the reused sref + * 2) Block reuse: maps a block stmt to the reused sref, + * where the block comes from the subtree of `tgt_stmt` + * 3) Intact reuse: not returned + */ + static std::unordered_map Prune(ScheduleStateNode* self, + const ReuseInfo& reuse_info, + const Stmt& src_stmt) { + SRefTreePruner pruner(self, reuse_info); + pruner.VisitStmt(src_stmt); + return std::move(pruner.reused_srefs_); + } + + private: + explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info) + : self_(self), reuse_info_(reuse_info) {} + + void VisitStmt_(const ForNode* op) final { + if (reuse_info_.intact.count(op)) { + return; + } + auto it = self_->stmt2ref.find(op); + ICHECK(it != self_->stmt2ref.end()) + << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n" + << GetRef(op); + StmtSRef& sref = it->second; + // Detect reuse + const VarNode* loop_var = op->loop_var.get(); + if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) { + // sref can be reused + reused_srefs_.emplace(loop_var, std::move(sref)); + } else { + sref->Reset(); + } + // erase the statement + self_->stmt2ref.erase(it); + // detect recursively + VisitStmt(op->body); + } + + void VisitStmt_(const BlockNode* op) final { + if (reuse_info_.intact.count(op)) { + return; + } + auto it = self_->stmt2ref.find(op); + ICHECK(it != self_->stmt2ref.end()) + << "IndexError: Cannot find correpsonding StmtSRef for the block:\n" + << GetRef(op); + StmtSRef& sref = it->second; + // Detect reuse + auto reuse_it = reuse_info_.block_sref_reuse.find(op); + if (reuse_it != reuse_info_.block_sref_reuse.end()) { + // sref can be reused + reused_srefs_.emplace(reuse_it->second, 
std::move(sref)); + } else { + sref->Reset(); + self_->block_info.erase(sref); + } + // erase the statement + self_->stmt2ref.erase(it); + // detect recursively + // op->init is omitted + VisitStmt(op->body); + } + + /*! \brief The schedule state we are working on */ + ScheduleStateNode* self_; + /*! \brief The reuse information we collected previously */ + const ReuseInfo& reuse_info_; + /*! + * \brief Reused srefs: + * 1) loop var -> StmtSRef + * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt` + */ + std::unordered_map reused_srefs_; +}; + +/*! + * \brief Update the sref in the `tgt_stmt` given the reuse information + * + * After being updated, in the `tgt_stmt` subtree, + * 1) all `StmtSRefNode::parent`s are correct + * 2) all `StmtSRefNode::seq_index`s are correct, except for the root + * 3) all `StmtSRefNode::stmt`s are correct, except for the root + */ +class SRefUpdater : public StmtVisitor { + public: + static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent, + const std::unordered_map& reused_srefs, + const Stmt& tgt_stmt) { + SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt); + } + + private: + explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent, + const std::unordered_map& reused_srefs) + : self_(GetRef(self)), + ancestors_{src_stmt_parent}, + reused_srefs_(reused_srefs) {} + + void VisitStmt_(const ForNode* op) final { + StmtSRef& sref = self_->stmt2ref[op]; + // Detect intact reuse + if (sref.defined()) { + sref->parent = ancestors_.back(); + sref->seq_index = -1; // `seq_index` will be set properly in SetSeqIndex + return; + } + // Detect loop reuse + auto it = reused_srefs_.find(op->loop_var.get()); + if (it != reused_srefs_.end()) { + // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]` + sref = it->second; + sref->stmt = op; + sref->parent = ancestors_.back(); + sref->seq_index = -1; // `seq_index` will be set properly in SetSeqIndex + } else { + // A 
new loop sref without reuse + sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(), + /*seq_index=*/-1); // `seq_index` will be set properly in SetSeqIndex + } + // Recursive visit + ancestors_.push_back(sref.get()); + VisitStmt(op->body); + ancestors_.pop_back(); + } + + void VisitStmt_(const BlockNode* op) final { + StmtSRef& sref = self_->stmt2ref[op]; + // Detect intact + if (sref.defined()) { + sref->parent = ancestors_.back(); + sref->seq_index = -1; // `seq_index` will be set properly in SetSeqIndex + return; + } + // Detect block reuse + auto it = reused_srefs_.find(op); + if (it != reused_srefs_.end()) { + // Update `stmt2ref[op]` to `reused_srefs_[op]` + sref = it->second; + sref->stmt = op; + sref->parent = ancestors_.back(); + sref->seq_index = -1; // `seq_index` will be set properly in SetSeqIndex + } else { + // A new block sref without reuse + sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(), + /*seq_index=*/-1); // `seq_index` will be set properly in SetSeqIndex + } + // Recursive visit + ancestors_.push_back(sref.get()); + VisitStmt(op->body); + ancestors_.pop_back(); + // Additionally, need to update the scope because the block is changed + UpdateBlockInfo(sref); + } + + void VisitStmt_(const SeqStmtNode* seq_stmt) final { + StmtVisitor::VisitStmt_(seq_stmt); + SetSeqIndexInChildren(self_.get(), seq_stmt); + } + + void UpdateBlockInfo(const StmtSRef& block_sref) { + using TIter = std::unordered_map::iterator; + // The caller is responsible for correcting the flags + BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref))); + std::pair insert_result = self_->block_info.emplace(block_sref, new_info); + bool inserted = insert_result.second; + BlockInfo& info = insert_result.first->second; + if (inserted) { + // Insertion has happened, update the flags accordingly + BlockInfo& info = insert_result.first->second; + info.affine_binding = false; + info.region_cover = false; + info.scope->stage_pipeline = false; + } else { + // 
Insertion didn't take place, because the entry has been there before. + // In this case, we assume that flags are still valid so intentionally keep them unchanged + info.scope = std::move(new_info.scope); + } + } + + /*! \brief The schedule state class to be worked on */ + ScheduleState self_; + /*! \brief A stack containing all the ancestor For/Block nodes during the visit */ + std::vector ancestors_; + /*! \brief Maps the loop var / block to the reused sref */ + const std::unordered_map& reused_srefs_; +}; + +/*! + * \brief A helper that returns a new copy of `parent_stmt`, + * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`. + * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree. + */ +class ChildReplacer : private StmtMutator { + public: + static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt, + const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) { + // Check the invariant + ICHECK(child_src_stmt->IsInstance() || // + child_src_stmt->IsInstance()); + ICHECK(child_tgt_stmt->IsInstance() || // + child_tgt_stmt->IsInstance() || // + child_tgt_stmt->IsInstance()); + ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index); + replacer.allow_copy_on_write_ = allow_copy_on_write; + return replacer.CopyOnWriteAndVisit(parent_stmt); + } + + private: + explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index) + : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {} + + Stmt VisitStmt(const Stmt& stmt) final { + if (stmt.get() == src_stmt_) { + // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt` + return tgt_stmt_; + } else { + return StmtMutator::VisitStmt(stmt); + } + } + + // Skipping sibling blocks and loops other than `src_stmt_` + Stmt VisitStmt_(const BlockNode* op) final { return GetRef(op); } + Stmt VisitStmt_(const ForNode* op) final { return GetRef(op); } + + Stmt 
VisitStmt_(const SeqStmtNode* op) final { + int i = this->seq_index_; + int n = static_cast(op->seq.size()); + if (0 <= i && i < n) { + const Stmt& stmt = op->seq[i]; + Optional new_stmt = NullOpt; + const StmtNode* src_stmt = this->src_stmt_; + // `stmt` can be For or BlockRealize + // `src_stmt` can be For or Block + // so the match from `stmt` to `src_stmt` can be + // 1) For -> For + // 2) BlockRealize -> Block + if (stmt.get() == src_stmt) { + // Case 1. src_stmt is For, stmt is For + new_stmt = tgt_stmt_; + } else if (const auto* realize = stmt.as()) { + // Case 2. stmt is BlockRealize, src_stmt is Block + if (realize->block.get() == src_stmt) { + const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode); + ObjectPtr new_realize = make_object(*realize); + new_realize->block = GetRef(tgt_block); + new_stmt = BlockRealize(std::move(new_realize)); + } + } + // Move new_stmt to position i + if (new_stmt.defined()) { + ObjectPtr new_seq_stmt = CopyOnWrite(op); + new_seq_stmt->seq.Set(i, new_stmt.value()); + return SeqStmt(std::move(new_seq_stmt)); + } + } + return StmtMutator::VisitStmt_(op); + } + + Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) { + // Step 1. Copy-on-write the `parent_stmt` and extract its `body`, + // where `body` means the body of either a block or a loop + // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt` + // and replace it with `child_tgt_stmt` + if (parent_stmt->IsInstance()) { + auto* block = const_cast(static_cast(parent_stmt)); + ObjectPtr new_block = CopyOnWrite(block); + new_block->body = this->VisitStmt(new_block->body); + return Block(std::move(new_block)); + } else if (parent_stmt->IsInstance()) { + auto* loop = const_cast(static_cast(parent_stmt)); + ObjectPtr new_loop = CopyOnWrite(loop); + new_loop->body = this->VisitStmt(new_loop->body); + return For(std::move(new_loop)); + } + LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey(); + throw; + } + + /*! 
\brief The `src_stmt` to be replaced */ + const StmtNode* src_stmt_; + /*! \brief The `tgt_stmt` to be replaced in */ + const Stmt& tgt_stmt_; + /*! + * \brief The `seq_index` of the `src_stmt` + * \sa StmtSRefNode + */ + int seq_index_; +}; + +void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, + const Map& _block_sref_reuse) { + if (this->debug_mode != 0) { + const StmtNode* src_stmt = _src_sref->stmt; + bool input_correct = + (src_stmt->IsInstance() && tgt_stmt->IsInstance()) || + (src_stmt->IsInstance() && tgt_stmt->IsInstance()) || + (src_stmt->IsInstance() && tgt_stmt->IsInstance()); + if (!input_correct) { + LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey() + << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n" + << GetRef(src_stmt) << "\ntgt_stmt:\n" + << tgt_stmt; + } + } + // Rule out the case that no replacement happens + if (_src_sref->stmt == tgt_stmt.get()) { + return; + } + // Reset sref as a new sref so that its content won't be affected by subsequent changes + StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index); + Stmt src_stmt = GetRef(src_sref->stmt); + // Step 1. Create all the nodes needed for the new sref tree. + // After this step + // 1) all `parent`s are correct + // 2) all `seq_index`s are correct, except for the root + // 3) all `stmt`s are correct, except for the root + { + // Step 0. Setup block_sref_reuse + std::unordered_map block_sref_reuse; + block_sref_reuse.reserve(_block_sref_reuse.size() + 1); + for (const auto& kv : _block_sref_reuse) { + block_sref_reuse.emplace(kv.first.get(), kv.second.get()); + } + // Step 1.1. Collect info for different kinds of reuses + // 1) intact + // 2) loop/block reuse + ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt); + reuse_info.block_sref_reuse = std::move(block_sref_reuse); + // Step 1.2. 
Collect loop/block reuse to their corresponding srefs + // and remove those srefs in the `src_stmt` that are no longer used after replacement + std::unordered_map reused_srefs = + SRefTreePruner::Prune(this, reuse_info, src_stmt); + // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused + // srefs in `tgt_stmt` + SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt); + } + // Step 2. Set the ancestors' children properly + // Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed. + // The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace. + // Along the way, because we create a new ancestor path, + // we need to update those sref points from old ancestors to newly created ones + // Variables: + // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that + // can be mutated inplace, it needs `num_copy_steps + 1` hops. + // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to. + // 3) `g_var` and `g_func`. 
Indicate which GlobalVar and PrimFunc the sref corresponds to + int num_copy_steps = -1; + bool need_module_copy = false; + const PrimFuncNode* g_func = nullptr; + GlobalVar g_var; + { + int i = 0; + const StmtSRefNode* p = src_sref.get(); + while (true) { + if (!p->stmt->unique()) { + num_copy_steps = i; + } + if (p->parent == nullptr) { + break; + } + ++i; + p = p->parent; + } + // Find `g_func` and `g_var` where the `src_sref` is in + g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var); + need_module_copy = num_copy_steps == i || // + !this->mod.unique() || // + !this->mod->functions.unique() || // + !g_func->unique(); + } + // Loop invariant: + // + // Before step `i`: + // 1) `child_sref` is `src_sref` going up by `i` steps + // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement + // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are + // correct 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet + // 5) `tgt_stmt` is of type Loop, Block or BlockRealize + // + // During step `i`: + // 1) Create `parent_stmt` that corresponds to `child_sref->parent` + // 2) Point `child_sref` to `child_tgt_stmt` + // 3) `tgt_stmt` is of type Loop or Block + StmtSRefNode* child_sref = src_sref.get(); + Stmt child_tgt_stmt = std::move(tgt_stmt); + for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) { + bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps; + // Replace `child_sref->stmt` to `child_tgt_stmt`. + const StmtNode* parent_stmt = child_sref->parent->stmt; + const StmtNode* child_src_stmt = child_sref->stmt; + // Step 2.1. 
Link `child_sref` to `child_tgt_stmt` + if (i == 0) { + // As the invariance of SRefUpdater, + // the `seq_index` of the root of `tgt_stmt` is set as -1, + // which might be incorrect + SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index); + } else { + // Point `child_sref` to `child_tgt_stmt` + UpdateSRef(this, child_sref, child_tgt_stmt.get()); + } + // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt` + Stmt new_parent_stmt = + ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt, + /*seq_index=*/child_sref->seq_index, + /*allow_copy_on_write=*/can_directly_mutate_parent); + // Step 2.3. Go to next parent + if (can_directly_mutate_parent) { + // If the node can be directly mutated inplace, + // then there is no need to update its parent and the function + break; + } + child_tgt_stmt = std::move(new_parent_stmt); + child_sref = child_sref->parent; + } + // Step 3. Handle the case that we mutate the root + if (need_module_copy) { + // From the loop invariant, upon exit, while its subtree is properly set, + // `child_sref` is not properly to `child_tgt_stmt` yet. 
+ if (src_sref->parent != nullptr) { + // Not replacing a root + UpdateSRef(this, child_sref, child_tgt_stmt.get()); + } + // Ensure the uniqueness of `this->mod` and `this->mod->functions` + IRModuleNode* new_mod = this->mod.CopyOnWrite(); + MapNode* new_map = new_mod->functions.CopyOnWrite(); + // Move out the PrimFunc where the sref belong while ensuring uniqueness + PrimFunc ref_new_func = Downcast(std::move(new_map->at(g_var))); + ICHECK(ref_new_func.get() == g_func); + PrimFuncNode* new_func = ref_new_func.CopyOnWrite(); + // If `g_func` was not unique, after the 3 lines above: + // `ref_new_func` points to a unique PrimFunc + // `g_func` points to the previous PrimFunc if it is not unique + // If `g_func` was unique, after the 3 lines above: + // `ref_new_func` points to the same unique function that `g_func` points to + // Update the body of the function the sref belongs to Assign + const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode); + // Make `child_tgt_stmt` the root block + const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode); + ObjectPtr new_realize = make_object(*realize); + new_realize->block = GetRef(child_block); + new_func->body = BlockRealize(std::move(new_realize)); + // Finally, move the `ref_new_func` back and update `this->mod` + new_map->at(g_var) = std::move(ref_new_func); + this->mod = GetRef(new_mod); + } + constexpr int kVerifySRefTree = static_cast(ScheduleDebugMask::kVerifySRefTree); + if (debug_mode == -1 || (debug_mode & kVerifySRefTree)) { + VerifySRefTree(GetRef(this)); + } +} + +void ScheduleStateNode::DebugVerify() const { + constexpr int kVerifySRefTree = static_cast(ScheduleDebugMask::kVerifySRefTree); + constexpr int kVerifyAffineBinding = static_cast(ScheduleDebugMask::kVerifyAffineBinding); + constexpr int kVerifyRegionCover = static_cast(ScheduleDebugMask::kVerifyRegionCover); + constexpr int kVerifyStagePipeline = static_cast(ScheduleDebugMask::kVerifyStagePipeline); + 
ICHECK_GE(debug_mode, -1); + if (debug_mode == -1 || (debug_mode & kVerifySRefTree)) { + VerifySRefTree(GetRef(this)); + } + if (debug_mode == -1 || (debug_mode & kVerifyAffineBinding)) { + // TODO(@junrushao1994): Verify affine block binding + } + if (debug_mode == -1 || (debug_mode & kVerifyRegionCover)) { + // TODO(@junrushao1994): Verify region cover + } + if (debug_mode == -1 || (debug_mode & kVerifyStagePipeline)) { + // TODO(@junrushao1994): Verify stage pipeline + } +} + +/**************** BlockInfo-related ****************/ + +BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const { + const auto* block = TVM_SREF_TO_BLOCK(block, block_sref); + auto it = this->block_info.find(block_sref); + CHECK(it != this->block_info.end()) + << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" + << GetRef(block_sref->stmt); + return it->second; +} + +/**************** FFI ****************/ + +TVM_REGISTER_NODE_TYPE(ScheduleStateNode); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState").set_body_typed([](ObjectRef obj, int debug_mode) { + if (const auto* func = obj.as()) { + return ScheduleState(GetRef(func), debug_mode); + } + if (const auto* mod = obj.as()) { + return ScheduleState(GetRef(mod), debug_mode); + } + LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey(); + throw; +}); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope") + .set_body_method(&ScheduleStateNode::GetBlockScope); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace") + .set_body_method(&ScheduleStateNode::Replace); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef") + .set_body_typed([](ScheduleState self, Stmt stmt) -> Optional { + auto it = self->stmt2ref.find(stmt.get()); + return it != self->stmt2ref.end() ? 
it->second : Optional(NullOpt); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h new file mode 100644 index 000000000000..63ec77dcf312 --- /dev/null +++ b/src/tir/schedule/utils.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_UTILS_H_ +#define TVM_TIR_SCHEDULE_UTILS_H_ + +#include +#include +#include +#include +#include +#include + +#include "./analysis.h" + +namespace tvm { +namespace tir { + +/*! + * \brief A helper macro to convert an sref to the statement it points to, + * then check if the downcasting succeeded. + * \param Result The result variable, used for checking + * \param SRef The SRef to be casted + * \param Type The type to be casted to, can be Block or For + */ +#define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \ + SRef->StmtAs(); \ + ICHECK(Result) + +/*! 
+ * \brief A helper macro to convert an sref to the block it points to, + * throwing an internal error if downcasting fails + * \param Result The result variable, used for checking + * \param SRef The SRef to be casted + */ +#define TVM_SREF_TO_BLOCK(Result, SRef) \ + TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::BlockNode) \ + << "TypeError: Expects StmtSRef `" << #SRef \ + << "` points to `Block`, but gets: " << (SRef->stmt ? SRef->stmt->GetTypeKey() : "None") + +/*! + * \brief A helper macro to convert an sref to the for-loop it points to, + * throwing an internal error if downcasting fails + * \param Result The name of the result variable, used for checking + * \param SRef The SRef to be casted + */ +#define TVM_SREF_TO_FOR(Result, SRef) \ + TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::ForNode) \ + << "TypeError: Expects StmtSRef `" << #SRef \ + << "` points to `Loop`, but gets: " << (SRef->stmt ? SRef->stmt->GetTypeKey() : "None") + +/*! + * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as`, + * then check if the downcasting succeeded. + * \param Result The result variable, used for checking + * \param From The ObjectRef to be downcasted + * \param Type The type to be downcasted to + */ +#define TVM_TYPE_AS_OR_ERR(Result, From, Type) \ + From.as(); \ + ICHECK(Result) + +/*! + * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as`, + * throwing an internal error if downcast fails. + * \param Result The result variable, used for checking + * \param From The ObjectRef to be downcasted + * \param Type The type to be downcasted to + */ +#define TVM_TYPE_AS(Result, From, Type) \ + TVM_TYPE_AS_OR_ERR(Result, From, Type) \ + << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \ + << "`, but gets: " << (From.defined() ? 
From->GetTypeKey() : "None") + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_UTILS_H_ diff --git a/tests/python/unittest/test_tir_block_scope.py b/tests/python/unittest/test_tir_block_scope.py new file mode 100644 index 000000000000..4a914f5063f8 --- /dev/null +++ b/tests/python/unittest/test_tir_block_scope.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-function-docstring,missing-module-docstring +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule import DepKind +from tvm.tir.stmt_functor import post_order_visit + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + C = tir.match_buffer(c, (128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = tir.float32(0) + for k in range(0, 128): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def war_dependency(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + +# pylint: enable=no-member,invalid-name,unused-variable + +# pylint: disable=invalid-name + + +def _get_block(s: tir.ScheduleState, name_hint: str) -> tir.StmtSRef: + result = None + + def f_visit(node): + nonlocal result + if isinstance(node, tvm.tir.Block) and node.name_hint == name_hint: + result = node + + func = s.mod["main"] + post_order_visit(func.body, f_visit) + assert result is not None and isinstance(result, tvm.tir.Block) + return s.get_sref(result) + + +def 
test_elementwise_dependency(): + s = tir.ScheduleState(elementwise, debug_mode=True) + root = _get_block(s, "root") + block_b = _get_block(s, "B") + block_c = _get_block(s, "C") + # Check get_deps_by_src + (dep,) = s.get_block_scope(root).get_deps_by_src(block_b) + assert dep.src.same_as(block_b) + assert dep.dst.same_as(block_c) + assert dep.kind == DepKind.RAW + # Check get_deps_by_dst + (dep,) = s.get_block_scope(root).get_deps_by_dst(block_c) + assert dep.src.same_as(block_b) + assert dep.dst.same_as(block_c) + assert dep.kind == DepKind.RAW + + +def test_matmul_dependency(): + s = tir.ScheduleState(matmul, debug_mode=True) + root = _get_block(s, "root") + init = _get_block(s, "init") + update = _get_block(s, "update") + # Check get_deps_by_src + p0, p1 = s.get_block_scope(root).get_deps_by_src(init) + assert p0.src.same_as(init) + assert p0.dst.same_as(update) + assert p1.src.same_as(init) + assert p1.dst.same_as(update) + assert (p0.kind == DepKind.RAW and p1.kind == DepKind.WAW) or ( + p0.kind == DepKind.WAW and p1.kind == DepKind.RAW + ) + # Check get_deps_by_dst + p0, p1 = s.get_block_scope(root).get_deps_by_dst(update) + assert p0.src.same_as(init) + assert p0.dst.same_as(update) + assert p1.src.same_as(init) + assert p1.dst.same_as(update) + assert (p0.kind == DepKind.RAW and p1.kind == DepKind.WAW) or ( + p0.kind == DepKind.WAW and p1.kind == DepKind.RAW + ) + + +def test_war_dependency(): + s = tir.ScheduleState(war_dependency, debug_mode=True) + root = _get_block(s, "root") + block_c = _get_block(s, "C") + block_b = _get_block(s, "B") + # Check get_deps_by_src + (dep,) = s.get_block_scope(root).get_deps_by_src(block_c) + assert dep.src.same_as(block_c) + assert dep.dst.same_as(block_b) + assert dep.kind == DepKind.WAR + # Check get_deps_by_dst + (dep,) = s.get_block_scope(root).get_deps_by_dst(block_b) + assert dep.src.same_as(block_c) + assert dep.dst.same_as(block_b) + assert dep.kind == DepKind.WAR + + +if __name__ == "__main__": + 
test_elementwise_dependency() + test_matmul_dependency() + test_war_dependency() diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py new file mode 100644 index 000000000000..ac98725ef9f8 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_state.py @@ -0,0 +1,352 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# pylint: disable=missing-function-docstring,missing-module-docstring
# Tests for tir.ScheduleState.replace: validates the copy-on-write (COW)
# semantics of sref-based IR replacement.
#
# Convention used throughout: a node's __hash__() is its object identity, so
#   old_hash == new_hash  => the node was written in place (no copy), and
#   old_hash != new_hash  => the node was copied.
# NOTE(review): these tests are sensitive to reference counts -- holding an
# extra Python reference to a node forces COW to copy it, so do not introduce
# additional temporaries when editing them.

import gc

import tvm
from tvm import tir
from tvm.ir import IRModule
from tvm.script import ty

# pylint: disable=no-member,invalid-name,unused-variable


@tvm.script.tir
def elementwise(a: ty.handle, c: ty.handle) -> None:
    # Two-stage elementwise workload: B = A * 2.0, then C = B + 1.0.
    # Its "main" body is a root block whose body is a 2-element SeqStmt,
    # which the tests below index as body.block.body[0] / body[1].
    A = tir.match_buffer(a, (128, 128), "float32")
    C = tir.match_buffer(c, (128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0


@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Matmul with an explicit "init" block; used only as a source of
    # foreign replacement statements in test_replace_block_remap.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "init") as [vi, vj]:
            C[vi, vj] = tir.float32(0)
        for k in range(0, 128):
            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


@tvm.script.tir
def block_in_opaque_block(a: ty.handle, b: ty.handle) -> None:
    # Blocks "D"/"F" nested inside opaque blocks "C"/"E" (blocks with no
    # iter vars, under an if/else); exercises replacement across
    # opaque-block boundaries in test_replace_block_in_opaque_block.
    A = tir.match_buffer(a, (128, 128), "float32")
    B = tir.match_buffer(b, (128, 128), "float32")
    with tir.block([128], "B") as vi:
        tir.reads([A[0:128, 0:128]])
        tir.writes([B[0:128, 0:128]])
        B[vi, 0] = A[vi, 0]
        if A[vi, 0] == 0.0:
            with tir.block([], "C"):
                tir.reads([A[0:128, 0:128]])
                tir.writes([B[0:128, 0:128]])
                with tir.block([128], "D") as vj:
                    B[vi, vj] = A[vi, vj] * 3.0
        else:
            with tir.block([], "E"):
                tir.reads([A[0:128, 0:128]])
                tir.writes([B[0:128, 0:128]])
                with tir.block([128], "F") as vj:
                    B[vi, vj] = A[vi, vj] * 2.0


# pylint: enable=no-member,invalid-name,unused-variable


def replace_ir_builder(deep_copy=False, realize=False):
    """Build a ScheduleState over a fresh copy of `elementwise`, plus a
    replacement target.

    Parameters
    ----------
    deep_copy : bool
        Round-trip the target through __getstate__/__setstate__ so it shares
        no nodes with the state's function body.
    realize : bool
        Wrap the target Block in a BlockRealize.

    Returns
    -------
    (s, target) : the ScheduleState and the statement to replace with.
    """
    # Round-trip through TVMScript so `s` holds the only reference to the IR.
    new_func = tvm.script.from_source(tvm.script.asscript(elementwise))
    s = tir.ScheduleState(new_func, debug_mode=True)
    target = tvm.tir.Block(
        iter_vars=[],
        reads=[],
        writes=[],
        name_hint="target",
        # Reuse the "C" stmt of the function as the new block's body.
        body=s.mod["main"].body.block.body[1],
        init=None,
        alloc_buffers=None,
        match_buffers=None,
        annotations=None,
    )
    if realize:
        target = tvm.tir.BlockRealize(
            iter_values=[],
            predicate=True,
            block=target,
        )
    if deep_copy:
        target.__setstate__(target.__getstate__())
    # Collect garbage so COW ref counts below are deterministic.
    gc.collect()
    return s, target


def replace_ir_builder_module(deep_copy=False, realize=False):
    """Same as replace_ir_builder, but the state wraps a two-function
    IRModule ("main" and "other") instead of a single PrimFunc."""
    new_func = tvm.script.from_source(tvm.script.asscript(elementwise))
    other_func = tvm.script.from_source(tvm.script.asscript(elementwise))
    mod = IRModule(functions={"main": new_func, "other": other_func})
    s = tir.ScheduleState(mod, debug_mode=True)
    target = tvm.tir.Block(
        iter_vars=[],
        reads=[],
        writes=[],
        name_hint="target",
        body=s.mod["main"].body.block.body[1],
        init=None,
        alloc_buffers=None,
        match_buffers=None,
        annotations=None,
    )
    if realize:
        target = tvm.tir.BlockRealize(
            iter_values=[],
            predicate=True,
            block=target,
        )
    if deep_copy:
        target.__setstate__(target.__getstate__())
    # Collect garbage so COW ref counts below are deterministic.
    gc.collect()
    return s, target


def replace_ir_builder_with_opaque():
    """Build a ScheduleState over a fresh copy of `block_in_opaque_block`."""
    func = tvm.script.from_source(tvm.script.asscript(block_in_opaque_block))
    s = tir.ScheduleState(func, debug_mode=True)
    gc.collect()
    return s


def test_replace_direct_write0():
    # Unique refs everywhere: the whole func is written in place.
    s, target = replace_ir_builder(realize=True)
    old_hash = s.mod["main"].__hash__()
    sref = s.get_sref(s.mod["main"].body.block.body[1])
    s.replace(sref, target)
    # There is no other reference so the AST node can be written directly
    assert old_hash == s.mod["main"].__hash__()
    # Check the replaced part is equal to the target
    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target)
    # The target reuses the stmt of the sref, so the sref won't be None
    assert sref.stmt is not None


def test_replace_direct_write1():
    # Same as write0, but identity is checked on the SeqStmt rather than
    # the whole func, and a child stmt is held alive across the replace.
    s, target = replace_ir_builder(realize=True)
    old_hash = s.mod["main"].body.block.body.__hash__()
    hold_ref = s.mod["main"].body.block.body[1]
    sref = s.get_sref(s.mod["main"].body.block.body[1])
    s.replace(sref, target)
    # There is no other reference so the AST node can be written directly
    assert old_hash == s.mod["main"].body.block.body.__hash__()
    assert not tvm.ir.structural_equal(hold_ref.body, target)
    # Check the replaced part is equal to the target
    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target)
    # The target reuses `sref.stmt`, so the sref won't be None
    assert sref.stmt is not None


def test_replace_copy():
    # An external reference to the func forces COW to copy the whole func.
    s, target = replace_ir_builder(deep_copy=True, realize=True)
    old_hash = s.mod["main"].__hash__()
    # We hold another reference of func
    old_func = s.mod["main"]
    sref = s.get_sref(s.mod["main"].body.block.body[0])
    s.replace(sref, target)
    # The whole func must be copied so that `old_func` stays unchanged
    assert old_hash != s.mod["main"].__hash__()
    assert not tvm.ir.structural_equal(old_func.body, s.mod["main"].body)
    assert old_hash == old_func.__hash__()
    # Check the replaced part is equal to the target
    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0], target)
    # The replaced AST node will be deleted, so the ref will be None
    assert sref.stmt is None


def test_replace_partial_copy0():
    # Only the held subtree is copied; the func and the sibling stmt are
    # written in place.
    s, target = replace_ir_builder(deep_copy=True, realize=True)
    func_old_hash = s.mod["main"].__hash__()
    hold_ref = s.mod["main"].body.block.body[0]
    ref_old_hash = hold_ref.__hash__()
    sref = s.get_sref(s.mod["main"].body.block.body[0].body)
    other_part_hash = s.mod["main"].body.block.body[1].__hash__()
    s.replace(sref, target)
    # The stmt is held by `hold_ref`, so it is copied by copy-on-write
    # because its ref count is not unique
    assert ref_old_hash != s.mod["main"].body.block.body[0].__hash__()
    assert not tvm.ir.structural_equal(hold_ref.body, target)
    # The function and the other part stmt can be directly written
    assert func_old_hash == s.mod["main"].__hash__()
    assert other_part_hash == s.mod["main"].body.block.body[1].__hash__()
    # Check the replaced part is equal to the target
    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0].body, target)
    # The replaced AST node will be deleted, so the ref will be None
    assert sref.stmt is None


def test_replace_partial_copy1():
    # Holding a grandchild: the parent stmt itself is still uniquely
    # referenced, so it is mutated in place while the held child is copied.
    s, target = replace_ir_builder(deep_copy=True)
    func_old_hash = s.mod["main"].__hash__()
    hold_ref = s.mod["main"].body.block.body[0].body
    stmt_old_hash = s.mod["main"].body.block.body[0].__hash__()
    sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
    other_part_hash = s.mod["main"].body.block.body[1].__hash__()
    s.replace(sref, target)
    # The parent stmt is mutated in place since there is only one reference
    assert stmt_old_hash == s.mod["main"].body.block.body[0].__hash__()
    assert not tvm.ir.structural_equal(hold_ref.body, target)
    # The function and the other part stmt can be directly written
    assert func_old_hash == s.mod["main"].__hash__()
    assert other_part_hash == s.mod["main"].body.block.body[1].__hash__()
    # Check the replaced part is equal to the target
    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0].body.body.block, target)
    # The replaced AST node will be deleted, so the ref will be None
    assert sref.stmt is None


def test_replace_root_write():
    # Replacing the root block with unique refs: pure in-place write.
    s, target = replace_ir_builder()
    old_hash = s.mod["main"].__hash__()
    sref = s.get_sref(s.mod["main"].body.block)
    s.replace(sref, target)
    # Check no copy and the new body equals to target
    assert old_hash == s.mod["main"].__hash__()
    tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)


def test_replace_root_copy0():
    # Replacing the root while the PrimFunc is externally held: the func
    # is copied, the held original is untouched.
    s, target = replace_ir_builder(deep_copy=True)
    old_hash = s.mod["main"].__hash__()
    func_ref = s.mod["main"]
    sref = s.get_sref(s.mod["main"].body.block)
    s.replace(sref, target)
    # Check the new body equals to target
    assert old_hash != s.mod["main"].__hash__()
    tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
    # Check the original func remains unchanged
    assert old_hash == func_ref.__hash__()
    assert not tvm.ir.structural_equal(func_ref.body, target)


def test_replace_root_copy1():
    # Same as copy0, but the externally-held node is the root Block.
    s, target = replace_ir_builder(deep_copy=True, realize=True)
    old_hash = s.mod["main"].body.block.__hash__()
    func_ref = s.mod["main"].body.block
    sref = s.get_sref(s.mod["main"].body.block.body[0])
    s.replace(sref, target)
    # Check the new body equals to target
    assert old_hash != s.mod["main"].body.block.__hash__()
    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0], target)
    # Check the original func remains unchanged
    assert old_hash == func_ref.__hash__()
    assert not tvm.ir.structural_equal(func_ref.body, target)


def test_replace_root_copy2():
    # Same as copy0, but the externally-held node is the module's
    # function map, so the map itself must be copied.
    s, target = replace_ir_builder(deep_copy=True)
    old_hash = s.mod.functions.__hash__()
    func_ref = s.mod.functions
    sref = s.get_sref(s.mod["main"].body.block)
    s.replace(sref, target)
    # Check the new body equals to target
    assert old_hash != s.mod.functions.__hash__()
    tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
    # Check the original function map remains unchanged
    assert old_hash == func_ref.__hash__()
    for _, v in func_ref.items():
        assert not tvm.ir.structural_equal(v.body.block, target)


def test_replace_root_copy3():
    # Same as copy0, but the externally-held node is the whole IRModule.
    s, target = replace_ir_builder(deep_copy=True)
    old_hash = s.mod.__hash__()
    func_ref = s.mod
    sref = s.get_sref(s.mod["main"].body.block)
    s.replace(sref, target)
    # Check the new body equals to target
    assert old_hash != s.mod.__hash__()
    tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
    # Check the original module remains unchanged
    assert old_hash == func_ref.__hash__()
    assert not tvm.ir.structural_equal(func_ref["main"].body.block, target)


def test_replace_block_remap():
    # Replace a block with a foreign one and pass a block_sref_reuse map:
    # the existing sref must be remapped to point at the new block.
    func = elementwise
    s = tir.ScheduleState(func, debug_mode=True)
    # The target stmt (the "init" block from matmul)
    target = matmul.body.block.body.body.body[0].block
    sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
    s.replace(sref, target, {sref.stmt: target})
    sref_new = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
    # Check the original sref has been remapped
    assert sref.__hash__() == sref_new.__hash__()
    tvm.ir.assert_structural_equal(sref.stmt, target)


def test_replace_block_in_opaque_block():
    # Replace a loop nested under opaque blocks; everything above is
    # uniquely referenced, so the func is written in place.
    s = replace_ir_builder_with_opaque()
    root_hash = s.mod["main"].__hash__()
    for_loop = s.mod["main"].body.block.body.body.block.body[1].then_case.block.body
    sref = s.get_sref(for_loop)
    new_for_loop = tir.For(
        loop_var=for_loop.loop_var,
        min_val=0,
        extent=128,
        kind=tir.ForKind.SERIAL,
        body=tir.Evaluate(0),
        thread_binding=None,
        annotations=None,
    )
    s.replace(sref, new_for_loop)
    assert root_hash == s.mod["main"].__hash__()
    tvm.ir.assert_structural_equal(sref.stmt, new_for_loop)


def test_replace_ir_module():
    # COW replacement inside a multi-function module: "main" is copied
    # (it is externally held) while "other" is left untouched.
    s, target = replace_ir_builder_module(deep_copy=True)
    old_hash = s.mod["main"].__hash__()
    other_func_hash = s.mod["other"].__hash__()
    func_ref = s.mod["main"]
    sref = s.get_sref(s.mod["main"].body.block)
    s.replace(sref, target)
    # Check the new body equals to target
    assert old_hash != s.mod["main"].__hash__()
    tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
    # Check the original func remains unchanged
    assert old_hash == func_ref.__hash__()
    assert not tvm.ir.structural_equal(func_ref.body, target)
    assert other_func_hash == s.mod["other"].__hash__()


if __name__ == "__main__":
    test_replace_direct_write0()
    test_replace_direct_write1()
    test_replace_copy()
    test_replace_partial_copy0()
    test_replace_partial_copy1()
    test_replace_root_write()
    test_replace_root_copy0()
    test_replace_root_copy1()
    test_replace_root_copy2()
    test_replace_root_copy3()
    test_replace_block_remap()
    test_replace_block_in_opaque_block()
    test_replace_ir_module()