diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
new file mode 100644
index 000000000000..32a2bc6e292d
--- /dev/null
+++ b/include/tvm/tir/usmp/utils.h
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/utils.h
+ * \brief Utilities for the Unified Static Memory Planner
+ */
+
+#ifndef TVM_TIR_USMP_UTILS_H_
+#define TVM_TIR_USMP_UTILS_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/target/target.h>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief The string parameter to indicate read and write access to a pool.
+ * This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
+ * python/tvm/tir/usmp/utils.py
+ */
+static constexpr const char* kTargetPoolReadWriteAccess = "rw";
+/*!
+ * \brief The string parameter to indicate read-only access to a pool.
+ * This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
+ * python/tvm/tir/usmp/utils.py
+ */
+static constexpr const char* kTargetPoolReadOnlyAccess = "ro";
+
+/*!
+ * \brief Describes a pool of memory accessible by one or more targets.
+ */
+struct PoolInfoNode : public Object {
+  /*! \brief The name of the memory pool */
+  String pool_name;
+  /*! \brief The size hint to be used by the allocator.
+   * size_hint_bytes defaults to kUnrestrictedPoolSizeHint
+   * to indicate the pool is not size-restricted.
+   */
+  Integer size_hint_bytes;
+  /*! \brief The accessibility ("rw" or "ro") from each Target */
+  Map<Target, String> target_access;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pool_name", &pool_name);
+    v->Visit("size_hint_bytes", &size_hint_bytes);
+    v->Visit("target_access", &target_access);
+  }
+
+  bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const {
+    return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) &&
+           equal(target_access, other->target_access);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(pool_name);
+    hash_reduce(size_hint_bytes);
+    hash_reduce(target_access);
+  }
+
+  static constexpr const char* _type_key = "tir.usmp.PoolInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
+};
+
+/*!
+ * \brief The pool size is unrestricted for the memory planner
+ */
+static const int kUnrestrictedPoolSizeHint = -1;
+
+class PoolInfo : public ObjectRef {
+ public:
+  TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access,
+                   Integer size_hint_bytes = kUnrestrictedPoolSizeHint);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode);
+};
+/*!
+ * \brief Describes an abstract memory buffer that will get allocated inside a pool.
+ * The actual memory buffer is represented by PoolAllocationNode after static memory planning.
+ *
+ * See also for relay-level counterparts:
+ * relay::StorageToken (graph_plan_memory.cc)
+ * relay::backend::StorageInfoNode (relay/backend/utils.h)
+ * Region (python/tvm/relay/transform/memory_plan.py)
+ */
+struct BufferInfoNode : public Object {
+  /*! \brief The name of the buffer var */
+  String name_hint;
+  /*! \brief The size in terms of bytes */
+  Integer size_bytes;
+  /*! \brief The pool candidates that this buffer can get pooled to */
+  Array<PoolInfo> pool_candidates;
+  /*! \brief The byte alignment required for buffers that will be placed within the pool */
+  Integer alignment;
+  /*! \brief The other buffer info objects whose liveness conflicts with this one */
+  Array<ObjectRef> conflicts;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("name_hint", &name_hint);
+    v->Visit("size_bytes", &size_bytes);
+    v->Visit("pool_candidates", &pool_candidates);
+    v->Visit("alignment", &alignment);
+    v->Visit("conflicts", &conflicts);
+  }
+
+  bool SEqualReduce(const BufferInfoNode* other, SEqualReducer equal) const {
+    return equal(name_hint, other->name_hint) && equal(size_bytes, other->size_bytes) &&
+           equal(pool_candidates, other->pool_candidates) && equal(alignment, other->alignment) &&
+           equal(conflicts, other->conflicts);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(name_hint);
+    hash_reduce(size_bytes);
+    hash_reduce(alignment);
+    hash_reduce(conflicts);
+    hash_reduce(pool_candidates);
+  }
+  /*!
+   * \brief Set the liveness conflicts of this BufferInfo
+   *
+   * \param conflicting_buffer_info_objs An array of BufferInfo that conflicts in liveness
+   */
+  TVM_DLL void SetConflicts(Array<ObjectRef> conflicting_buffer_info_objs);
+
+  static constexpr const char* _type_key = "tir.usmp.BufferInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BufferInfoNode, Object);
+};
+
+class BufferInfo : public ObjectRef {
+ public:
+  TVM_DLL BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
+                     Integer alignment = runtime::kDefaultWorkspaceAlignment);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfo, ObjectRef, BufferInfoNode);
+};
+
+/*!
+ * \brief The pool allocation produced after the USMP algorithm
+ */
+struct PoolAllocationNode : public Object {
+  /*! \brief The assigned PoolInfo object */
+  PoolInfo pool_info;
+  /*! \brief The byte offset where the tensor is supposed to be placed within the pool */
+  Integer byte_offset;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pool_info", &pool_info);
+    v->Visit("byte_offset", &byte_offset);
+  }
+
+  bool SEqualReduce(const PoolAllocationNode* other, SEqualReducer equal) const {
+    return equal(pool_info, other->pool_info) && equal(byte_offset, other->byte_offset);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(pool_info);
+    hash_reduce(byte_offset);
+  }
+
+  static constexpr const char* _type_key = "tir.usmp.PoolAllocation";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PoolAllocationNode, Object);
+};
+
+class PoolAllocation : public ObjectRef {
+ public:
+  TVM_DLL PoolAllocation(PoolInfo pool_info, Integer byte_offset);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolAllocation, ObjectRef, PoolAllocationNode);
+};
+
+/*!
+ * \brief Convert the IR-bound BufferInfo map to an array of BufferInfo
+ *
+ * \param buffer_info_map IR-bound BufferInfo map
+ */
+Array<BufferInfo> CreateArrayBufferInfo(const Map<BufferInfo, Stmt>& buffer_info_map);
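The intended composition of these three objects, as exercised through the Python bindings added later in this patch, is roughly the following. This is an illustrative sketch only; the pool names and sizes are made up:

```python
# Illustrative sketch of how PoolInfo, BufferInfo and PoolAllocation compose
# (names and sizes are invented; the APIs are the ones added in this patch).
from tvm.target import Target
from tvm.tir.usmp import utils as usmp_utils

target = Target("c")
# Two pools: a size-restricted fast pool and an unrestricted slow pool
fast = usmp_utils.PoolInfo(
    pool_name="fast_memory",
    target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
    size_hint_bytes=200704,
)
slow = usmp_utils.PoolInfo(
    pool_name="slow_memory",
    target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
)
# An abstract buffer that may be placed in either pool
buf = usmp_utils.BufferInfo("sid_8", 802816, [fast, slow], alignment=16)
# After planning, a buffer ends up as a byte offset into one pool
placement = usmp_utils.PoolAllocation(pool_info=slow, byte_offset=0)
```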
+
+/*!
+ * \brief The allocate node attribute to indicate candidate memory pools.
+ * This needs to be kept in sync with CANDIDATE_MEMORY_POOL_ATTR in
+ * python/tvm/tir/usmp/utils.py.
+ */
+static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_pools";
+
+/*!
+ * \brief Calculate the size of the extents in bytes
+ *
+ * \param op the allocate node
+ */
+Integer CalculateExtentsSize(const AllocateNode* op);
+
+}  // namespace usmp
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_USMP_UTILS_H_
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index 4750ad7626e2..0ce02d4cc244 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -110,6 +110,7 @@ def __init__(self):
         def allocate(extents, dtype, scope, condition=True, annotations=None, span=None):
             condition = tvm.runtime.convert(condition)
             scope = tvm.runtime.convert(scope)
+
             return tvm.tir.Allocate(
                 self.buffer_var,
                 dtype,
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 428403a98f16..07ceb29ebf98 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -55,3 +55,4 @@
 from . import transform
 from . import analysis
 from . import stmt_functor
+from . import usmp
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 978c630b17ad..a71476b23e44 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -411,6 +411,7 @@ def allocate(self, dtype, shape, name="buf", scope=""):
         scope : str, optional
             The scope of the buffer.
+
         Returns
         -------
         buffer : BufferVar
diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/__init__.py
new file mode 100644
index 000000000000..8aa0d4ccfe88
--- /dev/null
+++ b/python/tvm/tir/usmp/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import, redefined-builtin
+"""Namespace for Unified Static Memory Planner"""
+
+from . import analysis
+from .utils import BufferInfo
diff --git a/python/tvm/tir/usmp/_ffi_api.py b/python/tvm/tir/usmp/_ffi_api.py
new file mode 100644
index 000000000000..5899ef0c86ea
--- /dev/null
+++ b/python/tvm/tir/usmp/_ffi_api.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir.usmp""" +import tvm._ffi + + +tvm._ffi._init_api("tir.usmp", __name__) diff --git a/python/tvm/tir/usmp/analysis/__init__.py b/python/tvm/tir/usmp/analysis/__init__.py new file mode 100644 index 000000000000..756e8c7204c5 --- /dev/null +++ b/python/tvm/tir/usmp/analysis/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import, redefined-builtin +"""Namespace for Unified Static Memory Planner""" + +from .analysis import extract_buffer_info diff --git a/python/tvm/tir/usmp/analysis/_ffi_api.py b/python/tvm/tir/usmp/analysis/_ffi_api.py new file mode 100644 index 000000000000..36973f19905c --- /dev/null +++ b/python/tvm/tir/usmp/analysis/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir.usmp.analysis""" +import tvm._ffi + + +tvm._ffi._init_api("tir.usmp.analysis", __name__) diff --git a/python/tvm/tir/usmp/analysis/analysis.py b/python/tvm/tir/usmp/analysis/analysis.py new file mode 100644 index 000000000000..ff70355a967b --- /dev/null +++ b/python/tvm/tir/usmp/analysis/analysis.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""USMP Analysis Python API for passes""" +# pylint: disable=invalid-name +from . import _ffi_api +from ...function import PrimFunc +from ....ir.module import IRModule + + +def extract_buffer_info(main_func: PrimFunc, mod: IRModule): + """Convert Parallel For Loop to Serial. + + Parameters + ---------- + main_func: tvm.tir.PrimFunc + The main function containing calls to operator PrimFuncs. + mod : tvm.ir.IRModule + The full IRModule containing all PrimFuncs + + Returns + ------- + Map + extracted buffer info objects + """ + return _ffi_api.extract_buffer_info(main_func, mod) diff --git a/python/tvm/tir/usmp/utils.py b/python/tvm/tir/usmp/utils.py new file mode 100644 index 000000000000..0445775869e8 --- /dev/null +++ b/python/tvm/tir/usmp/utils.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""USMP Utilities and Data Structures""" +# pylint: disable=invalid-name + +from typing import Dict, Optional, List + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.target import Target +from . import _ffi_api + + +# The allocate node attribute to indicate candidate memory pools. +# This needs to be kept in sync with CANDIDATE_MEMORY_POOL_ATTR in +# include/tvm/tir/usmp/utils.h +CANDIDATE_MEMORY_POOL_ATTR = "candidate_memory_pools" + + +@register_object("tir.usmp.PoolInfo") +class PoolInfo(Object): + """PoolInfo object holds information related to memory pools + where the statically sized allocate nodes will pooled into. + + Parameters + ---------- + pool_name : str + The name of the memory pool + + target_access : Dict[Target, str] + A dictionary where keys describe which targets could + access the pool where value could take the values : + a) "rw" : read-write access + b) "ro" : write-only acesss + + size_hint_bytes : Optional[int] + The expected size hint to be used by the allocator. + The default value would be -1 which means the pool + is not size restricted. 
+ + """ + + # The string parameter to indicate read and write access to a pool + # This needs to be kept in sync with kTargetPoolReadWriteAccess in + # include/tvm/tir/usmp/utils.h + READ_WRITE_ACCESS = "rw" + # The string parameter to indicate read only access to a pool + # This needs to be kept in sync with kTargetPoolReadOnlyAccess in + # include/tvm/tir/usmp/utils.h + READ_ONLY_ACCESS = "ro" + + def __init__( + self, + pool_name: str, + target_access: Dict[Target, str], + size_hint_bytes: Optional[int] = None, + ): + self.__init_handle_by_constructor__( + _ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member + pool_name, + target_access, + size_hint_bytes, + ) + + +@register_object("tir.usmp.BufferInfo") +class BufferInfo(Object): + """BufferInfo object holds information related to buffers + that are associated with tir.allocates and tir.allocate_consts + that will be used with USMP + + Parameters + ---------- + name_hint : str + The name associated with the buffer (derived from TIR) + + size_bytes : int + The size in bytes + + pool_candidates : List[PoolInfo] + The list of candidates pools this buffer could be placed + + alignment : Optional[int] + The byte alignment required in the workspace memory + + """ + + def __init__( + self, + name_hint: str, + size_bytes: int, + pool_candidates: List[PoolInfo], + alignment: Optional[int] = None, + ): + self.__init_handle_by_constructor__( + _ffi_api.BufferInfo, # type: ignore # pylint: disable=no-member + name_hint, + size_bytes, + pool_candidates, + alignment, + ) + + def set_pool_candidates(self, pool_candidates: list): + """Sets the pool candidate names""" + _ffi_api.BufferInfoSetPoolCandidates(self, pool_candidates) + + def set_pool_offsets(self, pool_name: str, pool_offset: int): + """Sets the pool offset by name""" + _ffi_api.BufferInfoSetPoolOffset(self, pool_name, pool_offset) + + def set_conflicts(self, conflicts: list): + """Sets the the conflicting array of buffer info objects""" + _ffi_api.BufferInfoSetConflicts(self, conflicts) + + +@register_object("tir.usmp.PoolAllocation") +class PoolAllocation(Object): + """PoolAllocation object holds information related to an allocation + that indicates an offset in a pool + + Parameters + ---------- + pool_info : PoolInfo + The PoolInfo to which this allocation corresponds to + + byte_offset : int + The offset in the pool where the allocate node should be placed + + """ + + def __init__(self, pool_info: PoolInfo, byte_offset: int): + self.__init_handle_by_constructor__( + _ffi_api.PoolAllocation, # type: ignore # pylint: disable=no-member + pool_info, + byte_offset, + ) diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc new file mode 100644 index 000000000000..c25578fd9779 --- /dev/null +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -0,0 +1,446 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc
new file mode 100644
index 000000000000..c25578fd9779
--- /dev/null
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -0,0 +1,446 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/analysis/extract_buffer_info.cc
+ *
+ * \brief This analysis pass consumes a TIR IRModule with a main function
+ * that defines an ordering of calls to operators, and produces BufferInfo
+ * objects that contain information about tir.allocate nodes and their
+ * liveness conflicts with other tir.allocate nodes.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief The visitor class to obtain buffer information
+ *
+ * The visitor initiates the traversal from the main
+ * function and visits into the operator PrimFuncs. It will
+ * create a unique BufferInfo object for each Allocate node.
+ *
+ * Every time the buffer variable of an allocate node is referenced,
+ * it is recorded using the stmt index. However, note that
+ * the same buffer variable could be referenced multiple times
+ * from different calls. Thereafter, a sweep is done on all the
+ * BufferInfo objects using the per-call liveness events. In the sweep,
+ * the BufferInfo objects that are live together are recorded as
+ * mutual conflicts of each other.
+ */
+class BufferInfoExtractor : public StmtExprVisitor {
+ public:
+  explicit BufferInfoExtractor(const IRModule& module) : module_(module) {
+    for (const auto& gv_func : module_->functions) {
+      functions_.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second));
+    }
+    // Pushing a scope info for the initial body of the main function
+    scope_stack_.push(ScopeInfo());
+  }
+  Map<BufferInfo, Stmt> operator()(const PrimFunc& func);
+
+ private:
+  void VisitStmt(const Stmt& n) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const LoadNode* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+
+  void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
+  void RecordAllocateNodeInfo(const AllocateNode* op);
+  void VisitPrimFunc(const PrimFunc& func, const Call& call);
+
+  /*!
+   * \brief Maintains the mapping of BufferInfo to their associated TIR Statements.
+   */
+  Map<BufferInfo, Stmt> buffer_info_map_;
+  /*!
+   * \brief Records the order of calls in the main for stability.
+   */
+  std::set<Call> call_order_;
+  /*!
+   * \brief Records the first access, in terms of Stmts, to each buffer per call
+   *
+   * This is because multiple calls could happen to the same PrimFunc.
+   */
+  std::unordered_map<Call, Map<Stmt, Integer>, ObjectPtrHash, ObjectPtrEqual>
+      buffer_info_start_stmt_idx_;
+  /*!
+   * \brief Records the last access, in terms of Stmts, to each buffer per call
+   *
+   * This is because multiple calls could happen to the same PrimFunc.
+   */
+  std::unordered_map<Call, Map<Stmt, Integer>, ObjectPtrHash, ObjectPtrEqual>
+      buffer_info_end_stmt_idx_;
+  /*!
+   * \brief Maintains the mapping of buffer variables to their allocate nodes to ensure
+   * that only one BufferInfo object is created per allocate.
+   */
+  Map<Var, Stmt> allocate_var_to_stmt_map_;
+  /*!
+   * \brief A count of stmts visited so far, used as a metric of liveness
+   */
+  int current_stmt_idx_ = 0;
+  /*!
+   * \brief This structure contains information about the scope
+   * the visitor is currently in.
+   */
+  struct ScopeInfo {
+    /*!
+     * \brief We need to record access per call
+     */
+    Call call;
+    /*!
+     * \brief Having access to PrimFunc metadata is useful
+     */
+    PrimFunc func;
+    /*!
+     * \brief We currently support only serial for loops. Therefore,
+     * we need to know what kind of for loop the visitor is in.
+     */
+    For for_loop;
+    /*!
+     * \brief We record the live allocate nodes because, once inside a loop,
+     * the liveness range has to be extended to the whole of the nested
+     * loop structure.
+     */
+    std::unordered_set<Allocate, ObjectPtrHash, ObjectPtrEqual> allocate_nodes;
+    /*!
+     * \brief This is recorded to extend the liveness of all allocates within
+     * the nested loop structure.
+     */
+    Integer initial_stmt_of_the_nested_loops;
+  };
+  std::stack<ScopeInfo> scope_stack_;
+
+  /*!
+   * \brief A liveness event marks the point, when traversing the tir.Stmts,
+   * at which a tir.allocate node begins or ceases to be live. This struct
+   * is used to solve the interval overlap problem using a sweep-line
+   * algorithm, for which the liveness events need to be recorded in
+   * chronological order.
+   */
+  enum LivenessEventType { START = 0, END = 1 };
+  struct LivenessEvent {
+    size_t tick;
+    LivenessEventType le_type;
+    BufferInfo buffer_info;
+    bool operator==(const LivenessEvent& other) {
+      if (tick == other.tick && le_type == other.le_type && buffer_info == other.buffer_info) {
+        return true;
+      }
+      return false;
+    }
+  };
+  /*!
+   * \brief We need to create a unique buffer name if the same name is used in
+   * two allocate nodes, for the clarity of the memory planning algorithms.
+   */
+  std::string GetUniqueBufferName(std::string name);
+
+  /*!
+   * \brief A per-buffer-name counter to aid generating the unique names above.
+   */
+  std::unordered_map<std::string, int> buffer_names;
+  /*!
+   * \brief The TIR main function calls PrimFuncs by name, to be able to
+   * support BYOC. Therefore, this Map records the functions that are present
+   * in the IRModule by name.
+   */
+  Map<String, PrimFunc> functions_;
+  /*!
+   * \brief The IRModule being analyzed.
+   */
+  IRModule module_;
+};
+
+std::string BufferInfoExtractor::GetUniqueBufferName(std::string name) {
+  if (buffer_names.find(name) == buffer_names.end()) {
+    buffer_names[name] = 1;
+    return name;
+  } else {
+    buffer_names[name] = buffer_names[name] + 1;
+    return name + std::to_string(buffer_names[name]);
+  }
+}
+
+void BufferInfoExtractor::VisitStmt(const Stmt& n) {
+  current_stmt_idx_ += 1;
+  StmtExprVisitor::VisitStmt(n);
+}
+void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) {
+  auto size_bytes = CalculateExtentsSize(op);
+  // We can only statically memory plan allocates with known
+  // compile-time sizes.
+  if (size_bytes.defined() &&
+      allocate_var_to_stmt_map_.find(op->buffer_var) == allocate_var_to_stmt_map_.end()) {
+    // By default, the core compiler is assumed to attach a default pool to each allocate.
+    ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr))
+        << "Every statically sized allocate node needs a pool candidate attribute";
+    auto pool_candidates = Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesAllocateAttr]);
+
+    // TODO(@manupa-arm): improve the error when the responsible component for attaching a single
+    // pool is added
+    ICHECK(pool_candidates.size() > 0)
+        << "The core compiler should at least attach a single PoolInfo. If there were no "
+           "user-given arguments for memory pools, the default behaviour is that a single "
+           "size-unrestricted pool is assigned";
+    PrimFunc func = scope_stack_.top().func;
+    Optional<Target> tgt = func->GetAttr<Target>(tvm::attr::kTarget);
+    ICHECK(tgt) << "There should not be any PrimFuncs without a target attached by now";
+    auto workspace_alignment =
+        tgt.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
+    auto buffer_info = BufferInfo(GetUniqueBufferName(op->buffer_var->name_hint), size_bytes,
+                                  pool_candidates, workspace_alignment);
+    auto allocate = GetRef<Allocate>(op);
+    allocate_var_to_stmt_map_.Set(op->buffer_var, allocate);
+    buffer_info_map_.Set(buffer_info, allocate);
+  }
+}
+
+void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) {
+  ScopeInfo& current_scope_info = scope_stack_.top();
+  const auto& type = Downcast<PointerType>(op->buffer_var->type_annotation);
+  const auto& storage_scope = type->storage_scope;
+
+  // If the allocate is within a for loop, USMP currently only looks at serial for loops.
+  // If it is not a serial for loop, the memory planner will omit it from the current memory
+  // planning process, leaving it as a tir.allocate node for codegen. Additionally, USMP can
+  // only work with buffers that have global storage_scope
+
+  if (!current_scope_info.for_loop.defined()) {
+    RecordAllocateNodeInfo(op);
+  } else if (current_scope_info.for_loop.defined() &&
+             current_scope_info.for_loop->kind == ForKind::kSerial && storage_scope == "global") {
+    RecordAllocateNodeInfo(op);
+  }
+  StmtExprVisitor::VisitStmt(op->body);
+  current_scope_info.allocate_nodes.erase(GetRef<Allocate>(op));
+}
+
+void BufferInfoExtractor::VisitStmt_(const ForNode* op) {
+  ScopeInfo si{scope_stack_.top().call, scope_stack_.top().func, GetRef<For>(op),
+               scope_stack_.top().allocate_nodes,
+               scope_stack_.top().initial_stmt_of_the_nested_loops};
+  if (!scope_stack_.top().initial_stmt_of_the_nested_loops.defined()) {
+    si.initial_stmt_of_the_nested_loops = Integer(current_stmt_idx_);
+  }
+  Call current_call = scope_stack_.top().call;
+  scope_stack_.push(si);
+  StmtExprVisitor::VisitStmt_(op);
+  // Extend the liveness to the beginning of the loop nest and the end of the current for loop
+  for (const Allocate& allocate : scope_stack_.top().allocate_nodes) {
+    if (scope_stack_.top().initial_stmt_of_the_nested_loops->value <
+        buffer_info_start_stmt_idx_[current_call][allocate]) {
+      buffer_info_start_stmt_idx_[current_call].Set(
+          allocate, scope_stack_.top().initial_stmt_of_the_nested_loops->value);
+    }
+    if (current_stmt_idx_ > buffer_info_end_stmt_idx_[current_call][allocate]) {
+      buffer_info_end_stmt_idx_[current_call].Set(allocate, current_stmt_idx_);
+    }
+  }
+  scope_stack_.pop();
+}
+
+void BufferInfoExtractor::VisitExpr_(const LoadNode* op) {
+  this->VisitExpr(op->buffer_var);
+  StmtExprVisitor::VisitExpr_(op);
+}
+
+void BufferInfoExtractor::VisitStmt_(const StoreNode* op) {
+  this->VisitExpr(op->buffer_var);
+  StmtExprVisitor::VisitStmt_(op);
+}
+
+void BufferInfoExtractor::VisitExpr_(const VarNode* op) {
+  auto var = GetRef<Var>(op);
+  Call current_call = scope_stack_.top().call;
+  if (allocate_var_to_stmt_map_.count(var)) {
+    auto allocate = allocate_var_to_stmt_map_[var];
+    if (buffer_info_start_stmt_idx_[current_call].count(allocate) == 0) {
+      buffer_info_start_stmt_idx_[current_call].Set(allocate, current_stmt_idx_);
+    }
+    buffer_info_end_stmt_idx_[current_call].Set(allocate, current_stmt_idx_);
+
+    ScopeInfo& current_scope_info = scope_stack_.top();
+    if (current_scope_info.for_loop.defined()) {
+      current_scope_info.allocate_nodes.insert(Downcast<Allocate>(allocate));
+    }
+  }
+  StmtExprVisitor::VisitExpr_(op);
+}
+
+static Array<Var> GetMatchedBuffers(const PrimFunc& func) {
+  Array<Var> buffer_vars;
+  for (const auto& param : func->params) {
+    buffer_vars.push_back(func->buffer_map[param]->data);
+  }
+  return buffer_vars;
+}
+
+void BufferInfoExtractor::UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func) {
+  auto param_buffers = GetMatchedBuffers(func);
+  ICHECK(args.size() == param_buffers.size());
+  for (size_t i = 0; i < args.size(); i++) {
+    auto arg = args[i];
+    auto param_buf = param_buffers[i];
+    // If tir.allocates are passed in to functions,
+    // the function params are re-directed to point
+    // to the original allocate
+    if (arg->IsInstance<LoadNode>()) {
+      auto load = Downcast<Load>(arg);
+      if (allocate_var_to_stmt_map_.count(load->buffer_var)) {
+        allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[load->buffer_var]);
+      }
+    } else if (arg->IsInstance<VarNode>()) {
+      auto var = Downcast<Var>(arg);
+      if (allocate_var_to_stmt_map_.count(var)) {
+        allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[var]);
+      }
+    }
+  }
+}
+
+void BufferInfoExtractor::VisitPrimFunc(const PrimFunc& func, const Call& call) {
+  ScopeInfo si{call, func, scope_stack_.top().for_loop, scope_stack_.top().allocate_nodes,
+               scope_stack_.top().initial_stmt_of_the_nested_loops};
+  call_order_.insert(call);
+  scope_stack_.push(si);
+  this->VisitStmt(func->body);
+  scope_stack_.pop();
+}
+
+void BufferInfoExtractor::VisitExpr_(const CallNode* op) {
+  if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
+    StringImm func_name = Downcast<StringImm>(op->args[0]);
+    if (functions_.find(func_name->value) != functions_.end()) {
+      auto func = functions_.at(func_name->value);
+      auto actual_args = Array<PrimExpr>(op->args.begin() + 1, op->args.end());
+      this->UpdateAliases(actual_args, func);
+      VisitPrimFunc(func, GetRef<Call>(op));
+      return;
+    }
+  }
+  if (op->op->IsInstance<PrimFuncNode>()) {
+    auto func = Downcast<PrimFunc>(op->op);
+    this->UpdateAliases(op->args, func);
+    VisitPrimFunc(func, GetRef<Call>(op));
+    return;
+  }
+  StmtExprVisitor::VisitExpr_(op);
+}
+
+Map<BufferInfo, Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_func) {
+  VisitPrimFunc(main_func, Call());
+
+  // Create a vector of liveness events
+  // associated with each BufferInfo object.
+  std::vector<LivenessEvent> le_events_timeline;
+  for (const auto& kv1 : buffer_info_map_) {
+    if (!kv1.second->IsInstance<AllocateNode>()) {
+      continue;
+    }
+    auto allocate = Downcast<Allocate>(kv1.second);
+    auto buffer_info = Downcast<BufferInfo>(kv1.first);
+
+    ICHECK(call_order_.size() >= buffer_info_start_stmt_idx_.size());
+    ICHECK(call_order_.size() >= buffer_info_end_stmt_idx_.size());
+
+    for (const Call& call : call_order_) {
+      Map<Stmt, Integer> buffer_info_starts = buffer_info_start_stmt_idx_[call];
+      if (buffer_info_starts.find(allocate) != buffer_info_starts.end()) {
+        LivenessEvent le_event_start;
+        le_event_start.buffer_info = buffer_info;
+        le_event_start.le_type = START;
+        le_event_start.tick = buffer_info_starts[allocate];
+        le_events_timeline.push_back(le_event_start);
+      }
+    }
+
+    for (const Call& call : call_order_) {
+      Map<Stmt, Integer> buffer_info_ends = buffer_info_end_stmt_idx_[call];
+      if (buffer_info_ends.find(allocate) != buffer_info_ends.end()) {
+        LivenessEvent le_event_end;
+        le_event_end.buffer_info = buffer_info;
+        le_event_end.le_type = END;
+        le_event_end.tick = buffer_info_ends[allocate];
+        le_events_timeline.push_back(le_event_end);
+      }
+    }
+  }
+  // Sort the liveness events in chronological
+  // order. For events that are simultaneous, the START event
+  // takes precedence.
+  std::sort(le_events_timeline.begin(), le_events_timeline.end(),
+            [](const LivenessEvent& lhs, const LivenessEvent& rhs) {
+              if (lhs.tick < rhs.tick) {
+                return true;
+              } else if (lhs.tick == rhs.tick && lhs.le_type == START && rhs.le_type == END) {
+                return true;
+              }
+              return false;
+            });
+
+  // Traverse the liveness events using an open set to track what
+  // is live while updating the conflicts throughout the linear traversal
+  std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
+  for (const auto& le_event : le_events_timeline) {
+    if (le_event.le_type == START) {
+      for (const auto& open_buffer_info : open_set) {
+        open_buffer_info->conflicts.push_back(le_event.buffer_info);
+        if (le_event.buffer_info != open_buffer_info) {
+          le_event.buffer_info->conflicts.push_back(open_buffer_info);
+        }
+      }
+      open_set.insert(le_event.buffer_info);
+    } else {
+      open_set.erase(le_event.buffer_info);
+    }
+  }
+  return this->buffer_info_map_;
+}
+
+Map<BufferInfo, Stmt> ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) {
+  return BufferInfoExtractor(mod)(main_func);
+}
+
+TVM_REGISTER_GLOBAL("tir.usmp.analysis.extract_buffer_info")
+    .set_body_typed([](PrimFunc main_func, IRModule mod) {
+      return (ExtractBufferInfo(main_func, mod));
+    });
+
+}  // namespace usmp
+}  // namespace tir
+}  // namespace tvm
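Restated outside of TVM's data structures, the sweep above amounts to the following self-contained sketch (an illustration of the algorithm, not the pass itself):

```python
# Sweep-line conflict computation, assuming events are (tick, type, name)
# tuples. START sorts before END at equal ticks, matching the comparator above.
START, END = 0, 1

def compute_conflicts(events):
    conflicts = {name: set() for _, _, name in events}
    open_set = set()
    for _, le_type, name in sorted(events, key=lambda e: (e[0], e[1])):
        if le_type == START:
            # everything currently live conflicts with the newly started buffer
            for other in open_set:
                conflicts[other].add(name)
                conflicts[name].add(other)
            open_set.add(name)
        else:
            open_set.discard(name)
    return conflicts

# Two overlapping buffers conflict; a disjoint third one does not.
c = compute_conflicts([(0, START, "a"), (2, START, "b"), (3, END, "a"),
                       (4, END, "b"), (5, START, "c"), (6, END, "c")])
assert c["a"] == {"b"} and c["c"] == set()
```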
diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc
new file mode 100644
index 000000000000..b7177cc1635b
--- /dev/null
+++ b/src/tir/usmp/utils.cc
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/utils.cc
+ * \brief Utilities for the Unified Static Memory Planner
+ */
+
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+BufferInfo::BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
+                       Integer alignment) {
+  auto bufinfo_node = make_object<BufferInfoNode>();
+  bufinfo_node->name_hint = name_hint;
+  bufinfo_node->size_bytes = size_bytes;
+  bufinfo_node->pool_candidates = pool_candidates;
+  bufinfo_node->alignment = alignment;
+  data_ = std::move(bufinfo_node);
+}
+
+void BufferInfoNode::SetConflicts(Array<ObjectRef> conflicting_buffer_info_objs) {
+  this->conflicts = conflicting_buffer_info_objs;
+}
+
+TVM_REGISTER_NODE_TYPE(BufferInfoNode);
+TVM_REGISTER_GLOBAL("tir.usmp.BufferInfo")
+    .set_body_typed([](String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
+                       Integer alignment) {
+      if (!alignment.defined()) {
+        return BufferInfo(name_hint, size_bytes, pool_candidates);
+      }
+      return BufferInfo(name_hint, size_bytes, pool_candidates, alignment);
+    });
+TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoSetConflicts")
+    .set_body_method<BufferInfo>(&BufferInfoNode::SetConflicts);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<BufferInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const BufferInfoNode*>(ref.get());
+      p->stream << "BufferInfoNode(\n"
+                << "name_hint=" << node->name_hint << ",\n  size_bytes=" << node->size_bytes
+                << ",\n  pool_candidates=" << node->pool_candidates
+                << ",\n  alignment=" << node->alignment << ")";
+    });
+
+PoolInfo::PoolInfo(String pool_name, Map<Target, String> target_access, Integer size_hint_bytes) {
+  auto poolinfo_node = make_object<PoolInfoNode>();
+  poolinfo_node->pool_name = pool_name;
+  poolinfo_node->size_hint_bytes = size_hint_bytes;
+  poolinfo_node->target_access = target_access;
+  data_ = std::move(poolinfo_node);
+}
+
+TVM_REGISTER_NODE_TYPE(PoolInfoNode);
+TVM_REGISTER_GLOBAL("tir.usmp.PoolInfo")
+    .set_body_typed([](String pool_name, Map<Target, String> target_access,
+                       Integer size_hint_bytes) {
+      if (size_hint_bytes.defined()) {
+        return PoolInfo(pool_name, target_access, size_hint_bytes);
+      }
+      return PoolInfo(pool_name, target_access);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<PoolInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const PoolInfoNode*>(ref.get());
+      p->stream << "PoolInfoNode(\n"
+                << "pool_name=" << node->pool_name << ",\n  target_access=" << node->target_access
+                << ",\n  size_hint_bytes=" << node->size_hint_bytes << ")";
+    });
+
+PoolAllocation::PoolAllocation(PoolInfo pool_info, Integer byte_offset) {
+  auto pool_allocation_node = make_object<PoolAllocationNode>();
+  pool_allocation_node->pool_info = pool_info;
+  pool_allocation_node->byte_offset = byte_offset;
+  data_ = std::move(pool_allocation_node);
+}
+
+TVM_REGISTER_NODE_TYPE(PoolAllocationNode);
+TVM_REGISTER_GLOBAL("tir.usmp.PoolAllocation")
+    .set_body_typed([](PoolInfo pool_info, Integer byte_offset) {
+      return PoolAllocation(pool_info, byte_offset);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<PoolAllocationNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const PoolAllocationNode*>(ref.get());
+      p->stream << "PoolAllocationNode(\n"
+                << "pool_info=" << node->pool_info << ",\n  byte_offset=" << node->byte_offset
+                << ")";
+    });
+
+Array<BufferInfo> CreateArrayBufferInfo(const Map<BufferInfo, Stmt>& buffer_info_map) {
+  Array<BufferInfo> ret;
+  for (const auto& kv : buffer_info_map) {
+    auto buffer_info = kv.first;
+    ret.push_back(buffer_info);
+  }
+  return ret;
+}
+
+Integer CalculateExtentsSize(const AllocateNode* op) {
+  size_t element_size_bytes = op->dtype.bytes();
+  size_t num_elements = 1;
+  for (const auto& ext : op->extents) {
+    if (ext->IsInstance<IntImmNode>()) {
+      num_elements *= Downcast<IntImm>(ext)->value;
+    } else {
+      // We can't statically calculate workspace for dynamic shapes
+      return Integer();
+    }
+  }
+  return Integer(num_elements * element_size_bytes);
+}
+
+TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo")
+    .set_body_typed([](Map<BufferInfo, Stmt> buffer_info_map) {
+      return (CreateArrayBufferInfo(buffer_info_map));
+    });
+
+}  // namespace usmp
+}  // namespace tir
+}  // namespace tvm
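The arithmetic in CalculateExtentsSize is simple enough to restate; here is a hedged Python rendering of the same rule (not TVM API, just the math):

```python
# Byte size of an allocate: product of constant extents times element width;
# None when any extent is dynamic (mirrors the Integer() early-return above).
def calculate_extents_size(extents, dtype_bytes):
    num_elements = 1
    for ext in extents:
        if not isinstance(ext, int):
            return None  # can't statically plan dynamic shapes
        num_elements *= ext
    return num_elements * dtype_bytes

# e.g. PaddedInput_7 in the tests below: 157323 int16 elements -> 314646 bytes
assert calculate_extents_size([157323], 2) == 314646
```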
diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
new file mode 100644
index 000000000000..fa645f1379ff
--- /dev/null
+++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
@@ -0,0 +1,1555 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import sys
+
+import tvm
+from tvm import tir, script
+from tvm.ir import Range
+from tvm.script import tir as T
+from tvm.tir import stmt_functor
+from tvm.tir import PrimFunc
+from tvm.tir.usmp import utils as usmp_utils
+from tvm.target import Target
+
+
+def _replace_stmt_with_buf_var_names(buffer_info_map):
+    """helper to replace tir.allocates with buffer names"""
+    new_buffer_info_map = dict()
+    for k, v in buffer_info_map.items():
+        new_buffer_info_map[k.name_hint] = k
+    return new_buffer_info_map
+
+
+def _verify_conflicts(main_buf_name, conflicting_buf_names, buffer_info_map):
+    """helper to check expected liveness conflicts"""
+    buf_info = buffer_info_map[main_buf_name]
+    for conflict in buf_info.conflicts:
+        assert conflict.name_hint in conflicting_buf_names
+
+
+def _get_allocates(primfunc):
+    """helper to extract all allocate nodes by name"""
+    allocates = dict()
+
+    def get_allocate(stmt):
+        if isinstance(stmt, tvm.tir.Allocate):
+            allocates[str(stmt.buffer_var.name)] = stmt
+
+    stmt_functor.post_order_visit(primfunc.body, get_allocate)
+    return allocates
+
+
+def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
+    """helper to assign poolinfos to allocate nodes in a tir.PrimFunc"""
+
+    def set_poolinfos(stmt):
+        if isinstance(stmt, tvm.tir.Allocate):
+            return tvm.tir.Allocate(
+                buffer_var=stmt.buffer_var,
+                dtype=stmt.dtype,
+                extents=stmt.extents,
+                condition=stmt.condition,
+                body=stmt.body,
+                annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos},
+            )
+
+    return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos))
+
+
+def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
+    """helper to assign poolinfos to allocate nodes in an IRModule"""
+    ret = tvm.IRModule()
+    for global_var, basefunc in mod.functions.items():
+        if isinstance(basefunc, tvm.tir.PrimFunc):
+            ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos)
+    return ret
+
+
+def _assign_targets_to_primfuncs_irmodule(mod, target):
+    """helper to assign a target to each PrimFunc in an IRModule"""
+    ret = tvm.IRModule()
+    for global_var, basefunc in mod.functions.items():
+        if isinstance(basefunc, tvm.tir.PrimFunc):
+            ret[global_var] = basefunc.with_attr("target", target)
+    return ret
+
+
+# These are test IRModules that contain varied topologies of operator graphs,
+# including a main TIR function that calls into those operators.
+
+# fmt: off
+@tvm.script.ir_module
+class LinearStructure:
+    @T.prim_func
+    def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
+        placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        for ax0_ax1_fused_1 in T.serial(0, 224):
+            for ax2_1, ax3_inner_1 in T.grid(224, 3):
+                T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True)
+
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True})
+        placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1)
+        T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        PaddedInput_7 = T.allocate([157323], "int16", "global")
+        for i0_i1_fused_7 in T.serial(0, 229):
+            for i2_7, i3_7 in T.grid(229, 3):
+                T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True)
+        for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
+            Conv2dOutput_7 = T.allocate([64], "int32", "global")
+            for ff_3 in T.serial(0, 64):
+                T.store(Conv2dOutput_7, ff_3, 0, True)
+                for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
+                    T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True)
+            for ax3_inner_7 in T.serial(0, 64):
+                T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7,
ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global") + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_9 = T.allocate([301056], "int8", "global") + sid_8 = T.allocate([802816], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +def test_linear(): + target = Target("c") + fast_memory_pool = usmp_utils.PoolInfo( + pool_name="fast_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} + ) + slow_memory_pool = usmp_utils.PoolInfo( + pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} + ) + tir_mod = LinearStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) + tir_mod = _assign_poolinfos_to_allocates_in_irmodule( + tir_mod, [fast_memory_pool, slow_memory_pool] + ) + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(tir_mod["run_model"], tir_mod) + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) + + # check conflicts + _verify_conflicts("PaddedInput_7", ["sid_9", "sid_8", "Conv2dOutput_7"], buffer_info_map) + _verify_conflicts("tensor_2", ["sid_8"], buffer_info_map) + _verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map) + _verify_conflicts("sid_8", ["PaddedInput_7", "Conv2dOutput_7", "tensor_2"], buffer_info_map) + _verify_conflicts("Conv2dOutput_7", 
["sid_8", "PaddedInput_7"], buffer_info_map) + + # check sizes + assert buffer_info_map["sid_8"].size_bytes == 802816 + assert buffer_info_map["Conv2dOutput_7"].size_bytes == 256 + assert buffer_info_map["PaddedInput_7"].size_bytes == 314646 + assert buffer_info_map["tensor_2"].size_bytes == 200704 + assert buffer_info_map["sid_9"].size_bytes == 301056 + + # check_pool_candidates + assert [ + pool_info.pool_name for pool_info in list(buffer_info_map["sid_8"].pool_candidates) + ] == ["fast_memory", "slow_memory"] + + +# fmt: off +@tvm.script.ir_module +class ParallelSerialMixedForLoops: + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) + placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_8 = T.allocate([215296], "int16", "global") + for i0_i1_fused_8 in T.serial(0, 58): + for i2_8, i3_8 in T.grid(58, 64): + T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_8 in T.parallel(0, 3136): + dummy_allocate = T.allocate([1], "int32", "global") + for ax3_outer_4 in T.serial(0, 3): + Conv2dOutput_8 = T.allocate([64], "int32", "global") + for ff_4 in T.serial(0, 64): + T.store(Conv2dOutput_8, ff_4, 0, True) + for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): + T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), True) + for ax3_inner_8 in T.serial(0, 64): + T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", input, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), output, dtype="int32")) + + +__tvm_meta__ = None +# fmt: on + + +# fmt: off +@tvm.script.ir_module +class AllSerialForLoops: + @T.prim_func + def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) + placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_8 = T.allocate([215296], "int16", "global") + for i0_i1_fused_8 in T.serial(0, 58): + for i2_8, i3_8 in T.grid(58, 64): + T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): + dummy_allocate = T.allocate([1], "int32", "global") + for ax3_outer_4 in T.serial(0, 3): + Conv2dOutput_8 = T.allocate([64], "int32", "global") + for ff_4 in T.serial(0, 64): + T.store(Conv2dOutput_8, ff_4, 0, True) + for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): + T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), True) + for ax3_inner_8 in T.serial(0, 64): + T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", input, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), output, dtype="int32")) + + +__tvm_meta__ = None +# fmt: on + + +def test_parallel_serial_mixed_for_loops(): + target = Target("c") + global_ws_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + all_serial_tir_mod = AllSerialForLoops + all_serial_tir_mod = _assign_targets_to_primfuncs_irmodule(all_serial_tir_mod, target) + all_serial_tir_mod = _assign_poolinfos_to_allocates_in_irmodule( + all_serial_tir_mod, [global_ws_pool] + ) + main_func = all_serial_tir_mod["run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, all_serial_tir_mod) + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) + + # When all loops are serial all allocates are touched 
by USMP + assert len(buffer_info_map) == 3 + for name, _ in buffer_info_map.items(): + assert name in ["dummy_allocate", "Conv2dOutput_8", "PaddedInput_8"] + + parallel_serial_mixed_tir_mod = ParallelSerialMixedForLoops + parallel_serial_mixed_tir_mod = _assign_targets_to_primfuncs_irmodule( + parallel_serial_mixed_tir_mod, target + ) + parallel_serial_mixed_tir_mod = _assign_poolinfos_to_allocates_in_irmodule( + parallel_serial_mixed_tir_mod, [global_ws_pool] + ) + main_func = parallel_serial_mixed_tir_mod["run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info( + main_func, parallel_serial_mixed_tir_mod + ) + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) + + # USMP will not (yet) touch the allocates inside parallel for loops + assert len(buffer_info_map) == 2 + for name, _ in buffer_info_map.items(): + assert name in ["Conv2dOutput_8", "PaddedInput_8"] + + +# fmt: off +@tvm.script.ir_module +class InceptionStructure: + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d(placeholder: T.handle, tensor: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d", "tir.noalias": True}) + placeholder_1 = T.match_buffer(placeholder, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + tensor_1 = T.match_buffer(tensor, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused in T.serial(0, 28): + for ax2 in T.serial(0, 28): + for ax3_outer_init, ax3_inner_init in T.grid(3, 64): + T.store(tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer_init*64)) + ax3_inner_init), T.uint8(0), True) + for rv0_rv1_fused, ax3_outer, ax3_inner in T.grid(9, 3, 64): + T.store(tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner), T.max(T.load("uint8", tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)), T.if_then_else(((((ax0_ax1_fused*2) + T.floordiv(rv0_rv1_fused, 3)) < 56) and (((ax2*2) + T.floormod(rv0_rv1_fused, 3)) < 56)), T.load("uint8", placeholder_1.data, ((((((ax0_ax1_fused*21504) + (T.floordiv(rv0_rv1_fused, 3)*10752)) + (ax2*384)) + (T.floormod(rv0_rv1_fused, 3)*192)) + (ax3_outer*64)) + ax3_inner)), T.uint8(0), dtype="uint8")), True) + + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + + @T.prim_func + def tvmgen_default_fused_cast(placeholder_6: T.handle, T_cast: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast", "tir.noalias": True}) + placeholder_7 = T.match_buffer(placeholder_6, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, 
offset_factor=1) + T_cast_1 = T.match_buffer(T_cast, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_2 in T.serial(0, 28): + for ax2_2, ax3_outer_1, ax3_inner_2 in T.grid(28, 12, 16): + T.store(T_cast_1.data, ((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2), T.cast(T.load("uint8", placeholder_7.data, ((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_concatenate(placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, placeholder_11: T.handle, T_concat: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_concatenate", "tir.noalias": True}) + placeholder_12 = T.match_buffer(placeholder_8, [1, 28, 28, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_concat_1 = T.match_buffer(T_concat, [1, 28, 28, 256], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_13 = T.match_buffer(placeholder_9, [1, 28, 28, 128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_14 = T.match_buffer(placeholder_11, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_15 = T.match_buffer(placeholder_10, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_3 in T.serial(0, 28): + for ax2_3, ax3 in T.grid(28, 256): + T.store(T_concat_1.data, (((ax0_ax1_fused_3*7168) + (ax2_3*256)) + ax3), T.if_then_else((224 <= ax3), T.load("uint8", placeholder_14.data, ((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 224)), T.if_then_else((192 <= ax3), T.load("uint8", placeholder_15.data, ((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 192)), T.if_then_else((64 <= ax3), T.load("uint8", placeholder_13.data, ((((ax0_ax1_fused_3*3584) + (ax2_3*128)) + ax3) - 64)), T.load("uint8", placeholder_12.data, (((ax0_ax1_fused_3*1792) + (ax2_3*64)) + ax3)), dtype="uint8"), dtype="uint8"), dtype="uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_cast_2: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) + placeholder_19 = T.match_buffer(placeholder_16, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_3 = T.match_buffer(T_cast_2, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput = T.allocate([200704], "int16", "global") + for i0_i1_fused in T.serial(0, 56): + for i2, i3 in T.grid(56, 64): + T.store(PaddedInput, (((i0_i1_fused*3584) + (i2*64)) + i3), T.load("int16", placeholder_19.data, (((i0_i1_fused*3584) + (i2*64)) + i3)), True) + for ax0_ax1_fused_ax2_fused in T.serial(0, 3136): + Conv2dOutput = T.allocate([64], "int32", "global") + for ff in T.serial(0, 64): + T.store(Conv2dOutput, ff, 0, True) + for rc in T.serial(0, 64): + T.store(Conv2dOutput, ff, (T.load("int32", Conv2dOutput, ff) + (T.cast(T.load("int16", PaddedInput, ((ax0_ax1_fused_ax2_fused*64) + rc)), "int32")*T.cast(T.load("int16", 
placeholder_20.data, ((rc*64) + ff)), "int32"))), True) + for ax3_inner_3 in T.serial(0, 64): + T.store(T_cast_3.data, ((ax0_ax1_fused_ax2_fused*64) + ax3_inner_3), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_inner_3)), 1191576922, 31, -4, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, T_cast_4: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) + placeholder_25 = T.match_buffer(placeholder_22, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_26 = T.match_buffer(placeholder_23, [1, 1, 192, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_27 = T.match_buffer(placeholder_24, [1, 1, 1, 96], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_5 = T.match_buffer(T_cast_4, [1, 28, 28, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_1 = T.allocate([150528], "int16", "global") + for i0_i1_fused_1 in T.serial(0, 28): + for i2_1, i3_1 in T.grid(28, 192): + T.store(PaddedInput_1, (((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1), T.load("int16", placeholder_25.data, (((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)), True) + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 784): + Conv2dOutput_1 = T.allocate([1], "int32", "global") + for ax3_1 in T.serial(0, 96): + T.store(Conv2dOutput_1, 0, 0, True) + for rc_1 in T.serial(0, 192): + T.store(Conv2dOutput_1, 0, (T.load("int32", Conv2dOutput_1, 0) + (T.cast(T.load("int16", PaddedInput_1, ((ax0_ax1_fused_ax2_fused_1*192) + rc_1)), "int32")*T.cast(T.load("int16", placeholder_26.data, ((rc_1*96) + ax3_1)), "int32"))), True) + T.store(T_cast_5.data, ((ax0_ax1_fused_ax2_fused_1*96) + ax3_1), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_1, 0) + T.load("int32", placeholder_27.data, ax3_1)), 1201322342, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global") + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5 in 
T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", "tir.noalias": True}) + placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_2 = T.allocate([150528], "int16", "global") + for i0_i1_fused_2 in T.serial(0, 28): + for i2_2, i3_2 in T.grid(28, 192): + T.store(PaddedInput_2, (((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2), T.load("int16", placeholder_33.data, (((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)), True) + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 784): + Conv2dOutput_2 = T.allocate([64], "int32", "global") + for ff_1 in T.serial(0, 64): + T.store(Conv2dOutput_2, ff_1, 0, True) + for rc_2 in T.serial(0, 192): + T.store(Conv2dOutput_2, ff_1, (T.load("int32", Conv2dOutput_2, ff_1) + (T.cast(T.load("int16", PaddedInput_2, ((ax0_ax1_fused_ax2_fused_2*192) + rc_2)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_2*64) + ff_1)), "int32"))), True) + for ax3_inner_4 in T.serial(0, 64): + T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_2*64) + ax3_inner_4), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_2, ax3_inner_4) + T.load("int32", placeholder_35.data, ax3_inner_4)), 1663316467, 31, -7, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast_1(placeholder_36: T.handle, T_cast_10: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast_1", "tir.noalias": True}) + placeholder_37 = T.match_buffer(placeholder_36, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_11 = T.match_buffer(T_cast_10, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_3 = T.allocate([150528], "uint8", "global") + for ax0_ax1_fused_6 in T.serial(0, 28): + for ax2_6 in T.serial(0, 28): + for ax3_outer_init_1, ax3_inner_init_1 in T.grid(3, 64): + T.store(tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_init_1*64)) + ax3_inner_init_1), T.uint8(0), True) + for rv0_rv1_fused_2, ax3_outer_2, ax3_inner_5 in T.grid(9, 3, 64): + T.store(tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5), T.max(T.load("uint8", tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)), T.if_then_else(((((1 <= (T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6)) and ((T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6) < 29)) and (1 <= (ax2_6 + T.floormod(rv0_rv1_fused_2, 3)))) and ((ax2_6 + T.floormod(rv0_rv1_fused_2, 3)) < 29)), T.load("uint8", placeholder_37.data, (((((((T.floordiv(rv0_rv1_fused_2, 3)*5376) + 
(ax0_ax1_fused_6*5376)) + (ax2_6*192)) + (T.floormod(rv0_rv1_fused_2, 3)*192)) + (ax3_outer_2*64)) + ax3_inner_5) - 5568)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_7 in T.serial(0, 28): + for ax2_7, ax3_4 in T.grid(28, 192): + T.store(T_cast_11.data, (((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4), T.cast(T.load("uint8", tensor_3, (((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2(placeholder_38: T.handle, placeholder_39: T.handle, placeholder_40: T.handle, T_cast_12: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", "tir.noalias": True}) + placeholder_41 = T.match_buffer(placeholder_38, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_42 = T.match_buffer(placeholder_39, [1, 1, 192, 32], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_43 = T.match_buffer(placeholder_40, [1, 1, 1, 32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_13 = T.match_buffer(T_cast_12, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_3 = T.allocate([150528], "int16", "global") + for i0_i1_fused_3 in T.serial(0, 28): + for i2_3, i3_3 in T.grid(28, 192): + T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_41.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 784): + Conv2dOutput_3 = T.allocate([1], "int32", "global") + for ax3_5 in T.serial(0, 32): + T.store(Conv2dOutput_3, 0, 0, True) + for rc_3 in T.serial(0, 192): + T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_42.data, ((rc_3*32) + ax3_5)), "int32"))), True) + T.store(T_cast_13.data, ((ax0_ax1_fused_ax2_fused_3*32) + ax3_5), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_43.data, ax3_5)), 1811141736, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1136333842, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_44: T.handle, placeholder_45: T.handle, placeholder_46: T.handle, T_cast_14: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) + placeholder_47 = T.match_buffer(placeholder_44, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_48 = T.match_buffer(placeholder_45, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_49 = T.match_buffer(placeholder_46, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_15 = T.match_buffer(T_cast_14, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_4 = T.allocate([150528], "int16", "global") + for i0_i1_fused_4 in T.serial(0, 28): + for i2_4, i3_4 in T.grid(28, 192): + T.store(PaddedInput_4, (((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4), T.load("int16", 
placeholder_47.data, (((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)), True) + for ax0_ax1_fused_ax2_fused_4 in T.serial(0, 784): + Conv2dOutput_4 = T.allocate([1], "int32", "global") + for ax3_6 in T.serial(0, 16): + T.store(Conv2dOutput_4, 0, 0, True) + for rc_4 in T.serial(0, 192): + T.store(Conv2dOutput_4, 0, (T.load("int32", Conv2dOutput_4, 0) + (T.cast(T.load("int16", PaddedInput_4, ((ax0_ax1_fused_ax2_fused_4*192) + rc_4)), "int32")*T.cast(T.load("int16", placeholder_48.data, ((rc_4*16) + ax3_6)), "int32"))), True) + T.store(T_cast_15.data, ((ax0_ax1_fused_ax2_fused_4*16) + ax3_6), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_4, 0) + T.load("int32", placeholder_49.data, ax3_6)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1(placeholder_50: T.handle, placeholder_51: T.handle, placeholder_52: T.handle, T_cast_16: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", "tir.noalias": True}) + placeholder_53 = T.match_buffer(placeholder_50, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_54 = T.match_buffer(placeholder_51, [3, 3, 16, 32], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_55 = T.match_buffer(placeholder_52, [1, 1, 1, 32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_17 = T.match_buffer(T_cast_16, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_5 = T.allocate([14400], "int16", "global") + for i0_i1_fused_5 in T.serial(0, 30): + for i2_5, i3_5 in T.grid(30, 16): + T.store(PaddedInput_5, (((i0_i1_fused_5*480) + (i2_5*16)) + i3_5), T.if_then_else(((((1 <= i0_i1_fused_5) and (i0_i1_fused_5 < 29)) and (1 <= i2_5)) and (i2_5 < 29)), T.load("int16", placeholder_53.data, ((((i0_i1_fused_5*448) + (i2_5*16)) + i3_5) - 464)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_5 in T.serial(0, 784): + Conv2dOutput_5 = T.allocate([1], "int32", "global") + for ax3_7 in T.serial(0, 32): + T.store(Conv2dOutput_5, 0, 0, True) + for ry, rx, rc_5 in T.grid(3, 3, 16): + T.store(Conv2dOutput_5, 0, (T.load("int32", Conv2dOutput_5, 0) + (T.cast(T.load("int16", PaddedInput_5, (((((T.floordiv(ax0_ax1_fused_ax2_fused_5, 28)*480) + (ry*480)) + (rx*16)) + (T.floormod(ax0_ax1_fused_ax2_fused_5, 28)*16)) + rc_5)), "int32")*T.cast(T.load("int16", placeholder_54.data, ((((ry*1536) + (rx*512)) + (rc_5*32)) + ax3_7)), "int32"))), True) + T.store(T_cast_17.data, ((ax0_ax1_fused_ax2_fused_5*32) + ax3_7), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_5, 0) + T.load("int32", placeholder_55.data, ax3_7)), 1131968888, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1900719667, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_(placeholder_56: T.handle, placeholder_57: T.handle, placeholder_58: T.handle, T_cast_18: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", "tir.noalias": True}) + 
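# Editorially worth noting: each of these fused conv2d prim_funcs follows the same pattern — the input is copied (and padded where needed) into a "PaddedInput_*" scratch allocate, partial sums accumulate in a small "Conv2dOutput_*" allocate, and results are requantized through T.q_multiply_shift on the way out. + # These intra-operator allocates are what extract_buffer_info reports alongside the inter-operator "sid_*" workspaces in test_inception_structure below. +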
placeholder_59 = T.match_buffer(placeholder_56, [1, 28, 28, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_60 = T.match_buffer(placeholder_57, [3, 3, 96, 128], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_61 = T.match_buffer(placeholder_58, [1, 1, 1, 128], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_19 = T.match_buffer(T_cast_18, [1, 28, 28, 128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_6 = T.allocate([86400], "int16", "global") + for i0_i1_fused_6 in T.serial(0, 30): + for i2_6, i3_6 in T.grid(30, 96): + T.store(PaddedInput_6, (((i0_i1_fused_6*2880) + (i2_6*96)) + i3_6), T.if_then_else(((((1 <= i0_i1_fused_6) and (i0_i1_fused_6 < 29)) and (1 <= i2_6)) and (i2_6 < 29)), T.load("int16", placeholder_59.data, ((((i0_i1_fused_6*2688) + (i2_6*96)) + i3_6) - 2784)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_6 in T.serial(0, 784): + Conv2dOutput_6 = T.allocate([64], "int32", "global") + for ax3_outer_3 in T.serial(0, 2): + for ff_2 in T.serial(0, 64): + T.store(Conv2dOutput_6, ff_2, 0, True) + for ry_1, rx_1, rc_6 in T.grid(3, 3, 96): + T.store(Conv2dOutput_6, ff_2, (T.load("int32", Conv2dOutput_6, ff_2) + (T.cast(T.load("int16", PaddedInput_6, (((((T.floordiv(ax0_ax1_fused_ax2_fused_6, 28)*2880) + (ry_1*2880)) + (rx_1*96)) + (T.floormod(ax0_ax1_fused_ax2_fused_6, 28)*96)) + rc_6)), "int32")*T.cast(T.load("int16", placeholder_60.data, (((((ry_1*36864) + (rx_1*12288)) + (rc_6*128)) + (ax3_outer_3*64)) + ff_2)), "int32"))), True) + for ax3_inner_6 in T.serial(0, 64): + T.store(T_cast_19.data, (((ax0_ax1_fused_ax2_fused_6*128) + (ax3_outer_3*64)) + ax3_inner_6), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_6, ax3_inner_6) + T.load("int32", placeholder_61.data, ((ax3_outer_3*64) + ax3_inner_6))), 1374050734, 31, -7, dtype="int32"), 255), 0), "uint8"), "int32"), 1544713713, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_7, ff_3, 0, True) + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + 
T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + for ax3_inner_7 in T.serial(0, 64): + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) + placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_8 = T.allocate([215296], "int16", "global") + for i0_i1_fused_8 in T.serial(0, 58): + for i2_8, i3_8 in T.grid(58, 64): + T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): + Conv2dOutput_8 = T.allocate([64], "int32", "global") + for ax3_outer_4 in T.serial(0, 3): + for ff_4 in T.serial(0, 64): + T.store(Conv2dOutput_8, ff_4, 0, True) + for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): + T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), True) + for ax3_inner_8 in T.serial(0, 64): + T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_32 = T.allocate([301056], "int8", "global") + sid_20 = T.allocate([150528], "int8", "global") + sid_6 = T.allocate([401408], "int8", "global") + sid_9 = T.allocate([301056], "int8", "global") + sid_7 = T.allocate([401408], "int8", "global") + sid_8 = T.allocate([802816], "int8", "global") + sid_2 = 
T.allocate([50176], "int8", "global") + sid_3 = T.allocate([301056], "int8", "global") + sid_19 = T.allocate([100352], "int8", "global") + sid_4 = T.allocate([150528], "int8", "global") + sid_5 = T.allocate([602112], "int8", "global") + sid_25 = T.allocate([25088], "int8", "global") + sid_26 = T.allocate([25088], "int8", "global") + sid_31 = T.allocate([25088], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, sid_7, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_7, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_6, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", sid_6, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_5, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d", sid_5, sid_4, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast", sid_4, sid_3, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", sid_3, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_2, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_3, T.lookup_param("p9", dtype="handle"), T.lookup_param("p10", dtype="handle"), sid_20, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", sid_20, T.lookup_param("p11", dtype="handle"), T.lookup_param("p12", dtype="handle"), sid_19, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", sid_3, T.lookup_param("p13", dtype="handle"), T.lookup_param("p14", dtype="handle"), sid_26, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", sid_26, T.lookup_param("p15", dtype="handle"), T.lookup_param("p16", dtype="handle"), sid_25, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast_1", sid_4, sid_32, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", sid_32, T.lookup_param("p17", dtype="handle"), T.lookup_param("p18", dtype="handle"), sid_31, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_concatenate", sid_2, sid_19, sid_25, sid_31, output, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +def test_inception_structure(): + target = Target("c") + global_ws_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + tir_mod = InceptionStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) + tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool]) + main_func = tir_mod["run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, 
tir_mod) + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) + + # check conflicts + _verify_conflicts( + "PaddedInput_8", + [ + "sid_6", + "Conv2dOutput_8", + "sid_5", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_26", + [ + "PaddedInput_4", + "Conv2dOutput_4", + "PaddedInput_5", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput", + [ + "sid_6", + "PaddedInput", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_4", + [ + "sid_5", + "sid_3", + "tensor_3", + ], + buffer_info_map, + ) + _verify_conflicts( + "tensor_2", + [ + "sid_8", + "sid_7", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_7", + [ + "sid_8", + "PaddedInput_7", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_1", + [ + "sid_20", + "PaddedInput_1", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_4", + [ + "sid_26", + "PaddedInput_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_2", + [ + "PaddedInput_2", + "sid_2", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_3", + [ + "sid_32", + "sid_31", + "Conv2dOutput_3", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_3", + [ + "sid_4", + "PaddedInput_2", + "PaddedInput_1", + "PaddedInput_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_6", + [ + "PaddedInput_6", + "sid_19", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_5", + [ + "PaddedInput_5", + "sid_25", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_7", + [ + "sid_9", + "sid_8", + "Conv2dOutput_7", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_7", + [ + "tensor_2", + "PaddedInput", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_31", + [ + "PaddedInput_3", + "Conv2dOutput_3", + "sid_25", + "sid_2", + "sid_19", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_5", + [ + "Conv2dOutput_8", + "PaddedInput_8", + "sid_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_6", + [ + "PaddedInput", + "Conv2dOutput", + "PaddedInput_8", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_20", + [ + "PaddedInput_1", + "Conv2dOutput_1", + "PaddedInput_6", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_8", + [ + "PaddedInput_8", + "sid_5", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_1", + [ + "sid_3", + "sid_20", + "Conv2dOutput_1", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_3", + [ + "sid_31", + "PaddedInput_3", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput", + [ + "sid_7", + "sid_6", + "Conv2dOutput", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_2", + [ + "sid_3", + "Conv2dOutput_2", + "sid_2", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_19", + [ + "Conv2dOutput_6", + "PaddedInput_6", + "sid_31", + "sid_2", + "sid_25", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_4", + [ + "sid_3", + "sid_26", + "Conv2dOutput_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_5", + [ + "sid_26", + "Conv2dOutput_5", + "sid_25", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_6", + [ + "sid_20", + "Conv2dOutput_6", + "sid_19", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_25", + [ + "Conv2dOutput_5", + "PaddedInput_5", + "sid_31", + "sid_2", + "sid_19", + ], + buffer_info_map, + ) + _verify_conflicts( + "tensor_3", + [ + "sid_4", + "sid_32", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_32", + [ + "tensor_3", + "PaddedInput_3", + ], + buffer_info_map, 
+ ) + _verify_conflicts( + "sid_9", + [ + "PaddedInput_7", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_2", + [ + "Conv2dOutput_2", + "PaddedInput_2", + "sid_31", + "sid_25", + "sid_19", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_8", + [ + "PaddedInput_7", + "Conv2dOutput_7", + "tensor_2", + ], + buffer_info_map, + ) + + # check sizes + assert buffer_info_map["sid_20"].size_bytes == 150528 + assert buffer_info_map["tensor_2"].size_bytes == 200704 + assert buffer_info_map["sid_5"].size_bytes == 602112 + assert buffer_info_map["sid_9"].size_bytes == 301056 + assert buffer_info_map["Conv2dOutput_3"].size_bytes == 4 + assert buffer_info_map["sid_26"].size_bytes == 25088 + assert buffer_info_map["Conv2dOutput_2"].size_bytes == 256 + assert buffer_info_map["PaddedInput_5"].size_bytes == 28800 + assert buffer_info_map["sid_8"].size_bytes == 802816 + assert buffer_info_map["Conv2dOutput_5"].size_bytes == 4 + assert buffer_info_map["sid_3"].size_bytes == 301056 + assert buffer_info_map["Conv2dOutput"].size_bytes == 256 + assert buffer_info_map["PaddedInput_3"].size_bytes == 301056 + assert buffer_info_map["sid_32"].size_bytes == 301056 + assert buffer_info_map["PaddedInput_8"].size_bytes == 430592 + assert buffer_info_map["sid_4"].size_bytes == 150528 + assert buffer_info_map["PaddedInput_7"].size_bytes == 314646 + assert buffer_info_map["sid_6"].size_bytes == 401408 + assert buffer_info_map["Conv2dOutput_8"].size_bytes == 256 + assert buffer_info_map["sid_25"].size_bytes == 25088 + assert buffer_info_map["PaddedInput"].size_bytes == 401408 + assert buffer_info_map["sid_7"].size_bytes == 401408 + assert buffer_info_map["Conv2dOutput_1"].size_bytes == 4 + assert buffer_info_map["Conv2dOutput_4"].size_bytes == 4 + assert buffer_info_map["PaddedInput_2"].size_bytes == 301056 + assert buffer_info_map["sid_31"].size_bytes == 25088 + assert buffer_info_map["PaddedInput_1"].size_bytes == 301056 + assert buffer_info_map["Conv2dOutput_6"].size_bytes == 256 + assert buffer_info_map["PaddedInput_4"].size_bytes == 301056 + assert buffer_info_map["sid_2"].size_bytes == 50176 + assert buffer_info_map["tensor_3"].size_bytes == 150528 + assert buffer_info_map["Conv2dOutput_7"].size_bytes == 256 + assert buffer_info_map["sid_19"].size_bytes == 100352 + assert buffer_info_map["PaddedInput_6"].size_bytes == 172800 + + +# fmt: off +@tvm.script.ir_module +class MultipleCallsToSamePrimFuncModule: + @T.prim_func + def tvmgen_default_fused_layout_transform_1(placeholder: T.handle, T_layout_trans: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_layout_transform_1", "tir.noalias": True}) + placeholder_1 = T.match_buffer(placeholder, [1, 3, 24, 12], dtype="float32") + T_layout_trans_1 = T.match_buffer(T_layout_trans, [1, 1, 24, 12, 3], dtype="float32") + # body + for ax0_ax1_fused_ax2_fused, ax3, ax4_inner in T.grid(24, 12, 3): + T.store(T_layout_trans_1.data, ax0_ax1_fused_ax2_fused * 36 + ax3 * 3 + ax4_inner, T.load("float32", placeholder_1.data, ax4_inner * 288 + ax0_ax1_fused_ax2_fused * 12 + ax3), True) + + @T.prim_func + def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeholder_3: T.handle, conv2d_NCHWc: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_contrib_conv2d_NCHWc", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 1, 24, 12, 3], dtype="float32") + placeholder_5 = 
T.match_buffer(placeholder_3, [1, 1, 3, 3, 3, 3], dtype="float32") + conv2d_NCHWc_1 = T.match_buffer(conv2d_NCHWc, [1, 1, 24, 12, 3], dtype="float32") + # body + data_pad = T.allocate([1, 1, 26, 14, 3], "float32", "global") + for i0_i1_fused_i2_fused, i3, i4 in T.grid(26, 14, 3): + T.store(data_pad, i0_i1_fused_i2_fused * 42 + i3 * 3 + i4, T.if_then_else(1 <= i0_i1_fused_i2_fused and i0_i1_fused_i2_fused < 25 and 1 <= i3 and i3 < 13, T.load("float32", placeholder_4.data, i0_i1_fused_i2_fused * 36 + i3 * 3 + i4 - 39), T.float32(0), dtype="float32"), True) + for n_oc_chunk_fused_oh_fused in T.serial(0, 24): + conv2d_NCHWc_global = T.allocate([1, 1, 1, 12, 3], "float32", "global") + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 3, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 6, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 9, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 12, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 15, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 18, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 21, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 24, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 27, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 30, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 33, T.float32(0), True) + for kh, kw, ic_inner in T.grid(3, 3, 3): + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c, T.load("float32", conv2d_NCHWc_global, oc_block_c) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 3, T.load("float32", conv2d_NCHWc_global, oc_block_c + 3) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 3) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 6, T.load("float32", conv2d_NCHWc_global, oc_block_c + 6) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 6) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 9, T.load("float32", conv2d_NCHWc_global, oc_block_c + 9) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 9) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 12, T.load("float32", conv2d_NCHWc_global, oc_block_c + 12) + T.load("float32", 
data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 12) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 15, T.load("float32", conv2d_NCHWc_global, oc_block_c + 15) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 15) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 18, T.load("float32", conv2d_NCHWc_global, oc_block_c + 18) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 18) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 21, T.load("float32", conv2d_NCHWc_global, oc_block_c + 21) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 21) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 24, T.load("float32", conv2d_NCHWc_global, oc_block_c + 24) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 24) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 27, T.load("float32", conv2d_NCHWc_global, oc_block_c + 27) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 27) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 30, T.load("float32", conv2d_NCHWc_global, oc_block_c + 30) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 30) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 33, T.load("float32", conv2d_NCHWc_global, oc_block_c + 33) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 33) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for ow_inner, oc_block in T.grid(12, 3): + T.store(conv2d_NCHWc_1.data, n_oc_chunk_fused_oh_fused * 36 + ow_inner * 3 + oc_block, T.load("float32", conv2d_NCHWc_global, ow_inner * 3 + oc_block), True) + + @T.prim_func + def tvmgen_default_fused_nn_softmax_add_add_multiply_add(placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, T_add: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_softmax_add_add_multiply_add", "tir.noalias": True}) + placeholder_11 = T.match_buffer(placeholder_6, [1, 3, 24, 12], dtype="float32") + placeholder_12 = T.match_buffer(placeholder_7, [1, 3, 24, 12], dtype="float32") + placeholder_13 = T.match_buffer(placeholder_8, [3, 1, 1], dtype="float32") + placeholder_14 = T.match_buffer(placeholder_9, [3, 1, 1], dtype="float32") + placeholder_15 = T.match_buffer(placeholder_10, [3, 1, 1], dtype="float32") + T_add_1 = T.match_buffer(T_add, [1, 3, 24, 12], dtype="float32") + # body + for 
ax0_ax1_fused_ax2_fused in T.serial(0, 72): + T_softmax_norm = T.allocate([1, 1, 1, 12], "float32", "global") + with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: + T.store(T_softmax_maxelem, 0, T.float32(-3.4028234663852886e+38), True) + for k in T.serial(0, 12): + T.store(T_softmax_maxelem, 0, T.max(T.load("float32", T_softmax_maxelem, 0), T.load("float32", placeholder_11.data, ax0_ax1_fused_ax2_fused * 12 + k)), True) + T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") + for i3 in T.serial(0, 12): + T.store(T_softmax_exp, i3, T.exp(T.load("float32", placeholder_11.data, ax0_ax1_fused_ax2_fused * 12 + i3) - T.load("float32", T_softmax_maxelem, 0), dtype="float32"), True) + T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") + T.store(T_softmax_expsum, 0, T.float32(0), True) + for k in T.serial(0, 12): + T.store(T_softmax_expsum, 0, T.load("float32", T_softmax_expsum, 0) + T.load("float32", T_softmax_exp, k), True) + for i3 in T.serial(0, 12): + T.store(T_softmax_norm, i3, T.load("float32", T_softmax_exp, i3) / T.load("float32", T_softmax_expsum, 0), True) + for ax3 in T.serial(0, 12): + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused * 12 + ax3, (T.load("float32", placeholder_12.data, ax0_ax1_fused_ax2_fused * 12 + ax3) + T.load("float32", T_softmax_norm, ax3) + T.load("float32", placeholder_13.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24))) * T.load("float32", placeholder_14.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24)) + T.load("float32", placeholder_15.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24)), True) + + @T.prim_func + def tvmgen_default_fused_nn_contrib_dense_pack_nn_relu(placeholder_16: T.handle, placeholder_17: T.handle, T_relu: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", "tir.noalias": True}) + placeholder_18 = T.match_buffer(placeholder_16, [72, 12], dtype="float32") + placeholder_19 = T.match_buffer(placeholder_17, [2, 12, 6], dtype="float32") + T_relu_1 = T.match_buffer(T_relu, [72, 12], dtype="float32") + # body + for ax1_outer_ax0_outer_fused in T.serial(0, 18): + compute = T.allocate([8, 6], "float32", "global") + with T.allocate([8, 6], "float32", "global") as compute_global: + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 6, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 12, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 18, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 24, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 30, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 36, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 42, T.float32(0), True) + for k_outer in T.serial(0, 12): + for x_c in T.serial(0, 6): + T.store(compute_global, x_c, T.load("float32", compute_global, x_c) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 6, T.load("float32", compute_global, x_c + 6) + T.load("float32", placeholder_18.data, 
T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 12) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 12, T.load("float32", compute_global, x_c + 12) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 24) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 18, T.load("float32", compute_global, x_c + 18) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 36) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 24, T.load("float32", compute_global, x_c + 24) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 48) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 30, T.load("float32", compute_global, x_c + 30) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 60) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 36, T.load("float32", compute_global, x_c + 36) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 72) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 42, T.load("float32", compute_global, x_c + 42) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 84) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner, T.load("float32", compute_global, x_inner_inner), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 6, T.load("float32", compute_global, x_inner_inner + 6), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 12, T.load("float32", compute_global, x_inner_inner + 12), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 18, T.load("float32", compute_global, x_inner_inner + 18), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 24, T.load("float32", compute_global, x_inner_inner + 24), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 30, T.load("float32", compute_global, x_inner_inner + 30), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 36, T.load("float32", compute_global, x_inner_inner + 36), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 42, T.load("float32", compute_global, x_inner_inner + 42), True) + for ax0_inner_inner, ax1_inner_inner in T.grid(8, 6): + T.store(T_relu_1.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + ax0_inner_inner * 12 + T.floordiv(ax1_outer_ax0_outer_fused, 9) * 6 + ax1_inner_inner, T.max(T.load("float32", compute, ax0_inner_inner * 6 + ax1_inner_inner), T.float32(0)), True) + + 
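# The hand-unrolled dense-pack kernel above keeps two scratch allocates per outer iteration ("compute" and its "compute_global" accumulator). + # Since run_model invokes this prim_func twice, test_multiple_calls_to_same_primfunc below expects their liveness to conflict with the workspaces of both call sites (sid_11/sid_12 and sid_19/sid_20). +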
@T.prim_func + def tvmgen_default_fused_reshape_1(placeholder_20: T.handle, T_reshape: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_reshape_1", "tir.noalias": True}) + placeholder_21 = T.match_buffer(placeholder_20, [1, 3, 24, 12], dtype="float32") + T_reshape_1 = T.match_buffer(T_reshape, [72, 12], dtype="float32") + # body + for ax0, ax1_inner in T.grid(72, 12): + T.store(T_reshape_1.data, ax0 * 12 + ax1_inner, T.load("float32", placeholder_21.data, ax0 * 12 + ax1_inner), True) + + @T.prim_func + def tvmgen_default_fused_layout_transform(placeholder_22: T.handle, T_layout_trans_2: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_layout_transform", "tir.noalias": True}) + placeholder_23 = T.match_buffer(placeholder_22, [1, 1, 24, 12, 3], dtype="float32") + T_layout_trans_3 = T.match_buffer(T_layout_trans_2, [1, 3, 24, 12], dtype="float32") + # body + for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): + T.store(T_layout_trans_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, T.load("float32", placeholder_23.data, ax2 * 36 + ax3_inner * 3 + ax0_ax1_fused), True) + + @T.prim_func + def tvmgen_default_fused_reshape(placeholder_24: T.handle, T_reshape_2: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_reshape", "tir.noalias": True}) + placeholder_25 = T.match_buffer(placeholder_24, [72, 12], dtype="float32") + T_reshape_3 = T.match_buffer(T_reshape_2, [1, 3, 24, 12], dtype="float32") + # body + for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): + T.store(T_reshape_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, T.load("float32", placeholder_25.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner), True) + + @T.prim_func + def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27: T.handle, T_add_2: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_softmax_add", "tir.noalias": True}) + placeholder_28 = T.match_buffer(placeholder_26, [1, 3, 24, 12], dtype="float32") + placeholder_29 = T.match_buffer(placeholder_27, [1, 3, 24, 12], dtype="float32") + T_add_3 = T.match_buffer(T_add_2, [1, 3, 24, 12], dtype="float32") + # body + for ax0_ax1_fused_ax2_fused in T.serial(0, 72): + T_softmax_norm = T.allocate([1, 1, 1, 12], "float32", "global") + with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: + T.store(T_softmax_maxelem, 0, T.float32(-3.4028234663852886e+38), True) + for k in T.serial(0, 12): + T.store(T_softmax_maxelem, 0, T.max(T.load("float32", T_softmax_maxelem, 0), T.load("float32", placeholder_28.data, ax0_ax1_fused_ax2_fused * 12 + k)), True) + T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") + for i3 in T.serial(0, 12): + T.store(T_softmax_exp, i3, T.exp(T.load("float32", placeholder_28.data, ax0_ax1_fused_ax2_fused * 12 + i3) - T.load("float32", T_softmax_maxelem, 0), dtype="float32"), True) + T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") + T.store(T_softmax_expsum, 0, T.float32(0), True) + for k in T.serial(0, 12): + T.store(T_softmax_expsum, 0, T.load("float32", T_softmax_expsum, 0) + T.load("float32", T_softmax_exp, k), True) + for i3 in T.serial(0, 12): + T.store(T_softmax_norm, i3, T.load("float32", T_softmax_exp, i3) / T.load("float32", T_softmax_expsum, 0), True) + for ax3 in 
T.serial(0, 12): + T.store(T_add_3.data, ax0_ax1_fused_ax2_fused * 12 + ax3, T.load("float32", placeholder_29.data, ax0_ax1_fused_ax2_fused * 12 + ax3) + T.load("float32", T_softmax_norm, ax3), True) + + @T.prim_func + def run_model(data: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + data_buffer = T.match_buffer(data, [1, 3, 24, 12], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [1, 3, 24, 12], dtype="float32", align=16) + # body + sid_11 = T.allocate([3456], "int8", "global.workspace") + sid_5 = T.allocate([3456], "int8", "global.workspace") + sid_10 = T.allocate([3456], "int8", "global.workspace") + sid_6 = T.allocate([3456], "int8", "global.workspace") + sid_8 = T.allocate([3456], "int8", "global.workspace") + sid_2 = T.allocate([3456], "int8", "global.workspace") + sid_7 = T.allocate([3456], "int8", "global.workspace") + sid_3 = T.allocate([3456], "int8", "global.workspace") + sid_12 = T.allocate([3456], "int8", "global.workspace") + sid_4 = T.allocate([3456], "int8", "global.workspace") + sid_18 = T.allocate([3456], "int8", "global.workspace") + sid_19 = T.allocate([3456], "int8", "global.workspace") + sid_20 = T.allocate([3456], "int8", "global.workspace") + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", data_buffer.data, sid_8, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_8, T.cast(T.lookup_param("p0", dtype="handle"), "handle"), sid_7, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_7, sid_6, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", data_buffer.data, sid_12, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_12, T.cast(T.lookup_param("p1", dtype="handle"), "handle"), sid_11, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_11, sid_10, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add_add_multiply_add", sid_6, sid_10, T.cast(T.lookup_param("p2", dtype="handle"), "handle"), T.cast(T.lookup_param("p3", dtype="handle"), "handle"), T.cast(T.lookup_param("p4", dtype="handle"), "handle"), sid_5, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", sid_5, sid_4, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_4, T.cast(T.lookup_param("p5", dtype="handle"), "handle"), sid_3, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_3, sid_2, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", sid_5, sid_20, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_20, T.cast(T.lookup_param("p6", dtype="handle"), "handle"), sid_19, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_19, sid_18, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add", sid_2, sid_18, output_buffer.data, dtype="int32")) +# fmt: on + + +def test_multiple_calls_to_same_primfunc(): + target = Target("c") + global_ws_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + tir_mod = MultipleCallsToSamePrimFuncModule + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, 
+def test_multiple_calls_to_same_primfunc():
+    target = Target("c")
+    global_ws_pool = usmp_utils.PoolInfo(
+        pool_name="global_workspace",
+        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+    )
+    tir_mod = MultipleCallsToSamePrimFuncModule
+    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
+    tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
+    main_func = tir_mod["run_model"]
+    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+    buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
+
+    # check conflicts
+    _verify_conflicts(
+        "sid_18",
+        [
+            "sid_19",
+            "sid_2",
+            "T_softmax_exp2",
+            "T_softmax_maxelem2",
+            "T_softmax_expsum2",
+            "T_softmax_norm2",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_3",
+        [
+            "data_pad",
+            "conv2d_NCHWc_global",
+            "sid_2",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "T_softmax_norm",
+        [
+            "T_softmax_expsum",
+            "T_softmax_exp",
+            "sid_5",
+            "sid_6",
+            "T_softmax_maxelem",
+            "sid_10",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "T_softmax_norm2",
+        [
+            "T_softmax_expsum2",
+            "T_softmax_maxelem2",
+            "T_softmax_exp2",
+            "sid_18",
+            "sid_2",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_11",
+        [
+            "compute",
+            "sid_12",
+            "compute_global",
+            "sid_10",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_10",
+        [
+            "sid_11",
+            "sid_6",
+            "sid_5",
+            "T_softmax_norm",
+            "T_softmax_expsum",
+            "T_softmax_maxelem",
+            "T_softmax_exp",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_5",
+        [
+            "T_softmax_norm",
+            "T_softmax_expsum",
+            "T_softmax_exp",
+            "sid_6",
+            "T_softmax_maxelem",
+            "sid_10",
+            "sid_4",
+            "sid_20",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "T_softmax_expsum",
+        [
+            "T_softmax_exp",
+            "T_softmax_norm",
+            "sid_5",
+            "sid_6",
+            "T_softmax_maxelem",
+            "sid_10",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_8",
+        [
+            "data_pad",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "T_softmax_expsum2",
+        [
+            "T_softmax_maxelem2",
+            "T_softmax_exp2",
+            "sid_18",
+            "sid_2",
+            "T_softmax_norm2",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "T_softmax_maxelem2",
+        [
+            "T_softmax_exp2",
+            "sid_18",
+            "sid_2",
+            "T_softmax_expsum2",
+            "T_softmax_norm2",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_12",
+        [
+            "sid_11",
+            "compute",
+            "compute_global",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_19",
+        [
+            "sid_20",
+            "compute",
+            "compute_global",
+            "sid_18",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "conv2d_NCHWc_global",
+        [
+            "data_pad",
+            "sid_7",
+            "sid_3",
+            "data_pad",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "T_softmax_exp2",
+        [
+            "sid_18",
+            "sid_2",
+            "T_softmax_maxelem2",
+            "T_softmax_expsum2",
+            "T_softmax_norm2",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_7",
+        [
+            "conv2d_NCHWc_global",
+            "data_pad",
+            "sid_6",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "data_pad",
+        [
+            "sid_8",
+            "conv2d_NCHWc_global",
+            "sid_7",
+            "sid_4",
+            "sid_3",
+            "conv2d_NCHWc_global",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_20",
+        [
+            "sid_5",
+            "sid_19",
+            "compute",
+            "compute_global",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_4",
+        [
+            "sid_5",
+            "data_pad",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "T_softmax_exp",
+        [
+            "T_softmax_expsum",
+            "T_softmax_norm",
+            "sid_5",
+            "sid_6",
+            "T_softmax_maxelem",
+            "sid_10",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "compute_global",
+        [
+            "sid_12",
+            "sid_11",
+            "compute",
+            "compute",
+            "sid_20",
+            "sid_19",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "compute",
+        [
+            "sid_11",
+            "sid_12",
+            "compute_global",
+            "sid_20",
+            "sid_19",
+            "compute_global",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_6",
+        [
+            "sid_7",
+            "sid_5",
"T_softmax_norm", + "T_softmax_expsum", + "T_softmax_exp", + "T_softmax_maxelem", + "sid_10", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_maxelem", + [ + "sid_6", + "sid_5", + "T_softmax_norm", + "T_softmax_expsum", + "T_softmax_exp", + "sid_10", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_2", + [ + "sid_3", + "sid_18", + "T_softmax_exp2", + "T_softmax_maxelem2", + "T_softmax_expsum2", + "T_softmax_norm2", + ], + buffer_info_map, + ) + + +if __name__ == "__main__": + pytest.main([__file__] + sys.argv[1:]) diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py new file mode 100644 index 000000000000..232bf6a151fc --- /dev/null +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import sys + +import tvm +from tvm.script import tir as T +from tvm.tir import stmt_functor +from tvm.tir.usmp import utils as usmp_utils +from tvm.target import Target + + +# fmt: off +@tvm.script.ir_module +class LinearStructure: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dTpe="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", 
+        # body
+        PaddedInput_7 = T.allocate([157323], "int16", "global")
+        for i0_i1_fused_7 in T.serial(0, 229):
+            for i2_7, i3_7 in T.grid(229, 3):
+                T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True)
+        for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
+            Conv2dOutput_7 = T.allocate([64], "int32", "global")
+            for ff_3 in T.serial(0, 64):
+                T.store(Conv2dOutput_7, ff_3, 0, True)
+                for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
+                    T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True)
+            for ax3_inner_7 in T.serial(0, 64):
+                T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True)
+
+    @T.prim_func
+    def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
+        placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        tensor_2 = T.allocate([200704], "uint8", "global")
+        for ax0_ax1_fused_4 in T.serial(0, 56):
+            for ax2_4 in T.serial(0, 56):
+                for ax3_init in T.serial(0, 64):
+                    T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True)
+                for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
+                    T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True)
+        for ax0_ax1_fused_5 in T.serial(0, 56):
+            for ax2_5, ax3_3 in T.grid(56, 64):
+                T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)
+
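+    # run_model chains the three operators above; sid_9 and sid_8 hold the
+    # intermediate tensors whose BufferInfo test_create_array_buffer_info
+    # below extracts.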
T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +def test_create_pool_info(): + target = Target("c") + pool_info = usmp_utils.PoolInfo( + pool_name="foo_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + assert pool_info.pool_name == "foo_workspace" + assert dict(pool_info.target_access) == {target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} + # default pool size constraint + assert pool_info.size_hint_bytes == -1 + + pool_info = usmp_utils.PoolInfo( + pool_name="bar_workspace", + target_access={target: usmp_utils.PoolInfo.READ_ONLY_ACCESS}, + size_hint_bytes=1425, + ) + assert pool_info.pool_name == "bar_workspace" + assert dict(pool_info.target_access) == {target: usmp_utils.PoolInfo.READ_ONLY_ACCESS} + assert pool_info.size_hint_bytes == 1425 + + +def test_create_buffer_info(): + global_ws_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + buffer_info_obj = tvm.tir.usmp.BufferInfo( + name_hint="buf1", size_bytes=256, pool_candidates=[global_ws_pool] + ) + assert buffer_info_obj.name_hint == "buf1" + assert buffer_info_obj.size_bytes == 256 + assert list(buffer_info_obj.pool_candidates) == [global_ws_pool] + # default workspace alignment + assert buffer_info_obj.alignment == 1 + + buffer_info_obj = tvm.tir.usmp.BufferInfo("buf2", 512, [global_ws_pool], 8) + assert buffer_info_obj.name_hint == "buf2" + assert buffer_info_obj.size_bytes == 512 + assert list(buffer_info_obj.pool_candidates) == [global_ws_pool] + assert buffer_info_obj.alignment == 8 + + +def test_create_pool_allocation(): + pool_info = usmp_utils.PoolInfo( + pool_name="foo_workspace", + target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + pool_allocation = usmp_utils.PoolAllocation(pool_info=pool_info, byte_offset=64) + assert pool_allocation.pool_info == pool_info + assert pool_allocation.byte_offset == 64 + + +def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): + """helper to assing poolinfos to allocate nodes in a tir.PrimFunc""" + + def set_poolinfos(stmt): + if isinstance(stmt, tvm.tir.Allocate): + return tvm.tir.Allocate( + buffer_var=stmt.buffer_var, + dtype=stmt.dtype, + extents=stmt.extents, + condition=stmt.condition, + body=stmt.body, + annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos}, + ) + + return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos)) + + +def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): + """helper to assing poolinfos to allocate nodes in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) + return ret + + +def _assign_targets_to_primfuncs_irmodule(mod, target): + """helper to assign target for PrimFunc in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = basefunc.with_attr("target", target) + return ret + + +def test_create_array_buffer_info(): + target = Target("c") + global_ws_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + fcreate_array_bi 
+def test_create_array_buffer_info():
+    target = Target("c")
+    global_ws_pool = usmp_utils.PoolInfo(
+        pool_name="global_workspace",
+        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+    )
+    fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
+    tir_mod = LinearStructure
+    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
+    tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
+    main_func = tir_mod["tvmgen_default_run_model"]
+    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+    buffer_info_array = fcreate_array_bi(buffer_info_map)
+    for buffer_info in buffer_info_array:
+        assert buffer_info in buffer_info_map.keys()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__] + sys.argv[1:])