From 6129eb087e687ac299708b68c400d0104bad97e2 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 27 Oct 2021 12:33:57 +0100 Subject: [PATCH 1/7] [TIR][USMP] adding the pass to convert to pool offsets This commit adds a transform pass that consumes the planned pool allocations using a memory planning algorithm and converts them to pool offsets. * adds two test cases for a linear structure with two pools * adds test case with a single pool for residual structures Change-Id: I9d31e854461b5c21df72d1452120d286b96791c0 --- python/tvm/script/tir/__init__.py | 2 +- python/tvm/script/tir/ty.py | 1 + python/tvm/tir/usmp/__init__.py | 1 + python/tvm/tir/usmp/transform/__init__.py | 20 + python/tvm/tir/usmp/transform/_ffi_api.py | 21 + python/tvm/tir/usmp/transform/transform.py | 40 ++ src/printer/text_printer.h | 7 +- .../convert_pool_allocations_to_offsets.cc | 278 ++++++++++ src/tir/usmp/utils.cc | 15 + ...orm_convert_pool_allocations_to_offsets.py | 519 ++++++++++++++++++ 10 files changed, 900 insertions(+), 4 deletions(-) create mode 100644 python/tvm/tir/usmp/transform/__init__.py create mode 100644 python/tvm/tir/usmp/transform/_ffi_api.py create mode 100644 python/tvm/tir/usmp/transform/transform.py create mode 100644 src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc create mode 100644 tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py index 472b3de0e43b..de4045913102 100644 --- a/python/tvm/script/tir/__init__.py +++ b/python/tvm/script/tir/__init__.py @@ -17,7 +17,7 @@ """TVMScript for TIR""" # Type system -from .ty import int8, int16, int32, int64, float16, float32, float64 +from .ty import uint8, int8, int16, int32, int64, float16, float32, float64 from .ty import boolean, handle, Ptr, Tuple, Buffer from .prim_func import prim_func diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index 
2808e7a48735..0432692f5f4f 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -137,6 +137,7 @@ def __getitem__(self, args): pass # pylint: disable=unnecessary-pass +uint8 = ConcreteType("uint8") int8 = ConcreteType("int8") int16 = ConcreteType("int16") int32 = ConcreteType("int32") diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/__init__.py index 8aa0d4ccfe88..514727d52e2e 100644 --- a/python/tvm/tir/usmp/__init__.py +++ b/python/tvm/tir/usmp/__init__.py @@ -18,4 +18,5 @@ """Namespace for Unified Static Memory Planner""" from . import analysis +from . import transform from .utils import BufferInfo diff --git a/python/tvm/tir/usmp/transform/__init__.py b/python/tvm/tir/usmp/transform/__init__.py new file mode 100644 index 000000000000..1a9d83328f8d --- /dev/null +++ b/python/tvm/tir/usmp/transform/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=unused-import, redefined-builtin +"""Namespace for Unified Static Memory Planner""" + +from .transform import convert_pool_allocations_to_offsets diff --git a/python/tvm/tir/usmp/transform/_ffi_api.py b/python/tvm/tir/usmp/transform/_ffi_api.py new file mode 100644 index 000000000000..7973ca5b0da0 --- /dev/null +++ b/python/tvm/tir/usmp/transform/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir.usmp.analysis""" +import tvm._ffi + + +tvm._ffi._init_api("tir.usmp.transform", __name__) diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py new file mode 100644 index 000000000000..5739d5f78067 --- /dev/null +++ b/python/tvm/tir/usmp/transform/transform.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""USMP Transform Python API for passes""" +# pylint: disable=invalid-name + +from typing import Dict + +from . import _ffi_api +from ....tir import Stmt +from ..utils import PoolAllocation + + +def convert_pool_allocations_to_offsets(pool_allocations: Dict[Stmt, PoolAllocation]): + """Convert pool allocations to Load nodes with offsets from pools. + + Parameters + ---------- + pool_allocations : Dict[Stmt, PoolAllocation] + Allocate or AllocateConst node to pool allocation mapping + + Returns + ------- + ret: tvm.transform.Pass + The registered pass that converts the allocations to offsets. 
+ """ + return _ffi_api.ConvertPoolAllocationsToOffsets(pool_allocations) diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index ebd667ae2ac7..97146b84450d 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -449,10 +449,11 @@ class TextPrinter { Doc PrintFinal(const ObjectRef& node) { Doc doc; - if (node->IsInstance()) { + if (node.defined() && node->IsInstance()) { doc << PrintMod(Downcast(node)); - } else if (node->IsInstance() || node->IsInstance() || - node->IsInstance()) { + } else if (node.defined() && + (node->IsInstance() || node->IsInstance() || + node->IsInstance())) { doc << tir_text_printer_.Print(node); } else { doc << relay_text_printer_.PrintFinal(node); diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc new file mode 100644 index 000000000000..fb1ec91ec149 --- /dev/null +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tir/analysis/usmp/transform/convert_pool_allocations_to_offsets.cc + * \brief This pass would convert the pool allocations to offsets from pools + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace tir { +namespace usmp { + +class PoolAllocationToOffsetConverter : public StmtExprMutator { + public: + explicit PoolAllocationToOffsetConverter(const IRModule& module, + const Map& pool_allocations) + : pool_allocations_(pool_allocations) { + module_ = module->ShallowCopy(); + for (const auto& gv_func : module_->functions) { + function_global_vars_.Set(gv_func.first->name_hint, gv_func.first); + } + for (const auto& kv : pool_allocations) { + // TODO(@manupa-arm): add AllocateConstNode when it is available + ICHECK(kv.first->IsInstance()); + Allocate allocate_node = Downcast(kv.first); + PoolAllocation pool_allocation = kv.second; + PoolInfo pool_info = pool_allocation->pool_info; + pool_ordering_.insert(pool_info); + int byte_pool_offset = pool_allocation->byte_offset->value; + int required_pool_size_for_allocation = + byte_pool_offset + CalculateExtentsSize(allocate_node.operator->()); + if (all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) { + all_pools_sizes_[pool_info] = required_pool_size_for_allocation; + } else { + int prev_required_pool_size = all_pools_sizes_[pool_info]; + if (prev_required_pool_size < required_pool_size_for_allocation) { + all_pools_sizes_[pool_info] = required_pool_size_for_allocation; + } + } + } + } + IRModule operator()(); + + private: + PrimExpr VisitExpr_(const CallNode* op) override; + Stmt VisitStmt_(const AllocateNode* op) override; + PrimExpr VisitExpr_(const LoadNode* op) override; + Stmt VisitStmt_(const StoreNode* op) override; + + /*! \brief This is a structure where the modified function + * signature is kept while body of the function is mutated + */ + struct ScopeInfo { + Array params; + Map pools_to_params; + Map buffer_map; + }; + + /*! 
\brief The function scope information that are needed + * in the mutation of the function need to be stacked and + * popped when each function is entered/exited in the + * mutation process. + */ + std::stack scope_stack; + /*! \brief Each PrimFunc signature needs to be updated + * with pool variables. This is a helper function to + * capture the updated information to ScopeInfo object. + */ + ScopeInfo UpdateFunctionScopeInfo(const PrimFunc& original_func); + /*! \brief This is a helper to create the PrimFunc with + * pool variables that calls the UpdateFunctionScopeInfo + * inside of it. + */ + PrimFunc CreatePrimFuncWithPoolParams(const PrimFunc& original_primfunc); + /*! \brief This is a helper to append the pool args to + * the callsite of the function. + */ + Array AppendPoolParamsToArgs(const CallNode* op); + /*! \brief Some arguments that used to be Allocate nodes + * should be replaced by Let nodes in the pass that loads + * the space from a pool variable. + */ + Array ReplaceAllocateArgsWithLetArgs(const Array& args); + + /*! \brief The tir::Var map to PoolInfo objects */ + Map primfunc_args_to_pool_info_map_; + /*! \brief The buffer var map to their allocate nodes */ + Map allocate_var_to_stmt_map_; + /*! \brief The IRModule being constructed/mutated */ + IRModule module_; + /*! \brief The input allocate node to PoolAllocation map */ + Map pool_allocations_; + /*! \brief The set of ordered pools to ensure an unique order of args for functions */ + std::set pool_ordering_; + /*! \brief The storage of calculated pool size at init */ + std::unordered_map all_pools_sizes_; + /*! \brief The AoT codegen uses extern_calls due to some functions not being exposed in the TIR + * IRModule This maps maintains the map of which to each function + */ + Map function_global_vars_; + /*! 
\brief After mutation, each allocate buffer is replaced with tir::Var that is let bounded + * to position from a pool as designated by a PoolAllocation + */ + Map allocate_buf_to_let_var_; + /*! \brief A counter to give references to pools a reproducible unique set of names */ + int pool_var_count_ = 0; +}; + +PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo( + const PrimFunc& original_func) { + ScopeInfo si; + si.params = original_func->params; + si.buffer_map = original_func->buffer_map; + Map ret; + for (const PoolInfo& pool_info : pool_ordering_) { + String pool_ref_name = pool_info->pool_name + "_" + std::to_string(pool_var_count_++); + String var_name = pool_ref_name + "_var"; + DataType elem_dtype = DataType::UInt(8); + Var buffer_var(var_name, PointerType(PrimType(elem_dtype), "global")); + Var pool_var(var_name, DataType::Handle()); + si.params.push_back(pool_var); + si.pools_to_params.Set(pool_info, pool_var); + + int pool_size = all_pools_sizes_[pool_info]; + String buffer_var_name = pool_ref_name + "_buffer_var"; + si.buffer_map.Set(pool_var, Buffer(buffer_var, elem_dtype, {pool_size}, {1}, 1, buffer_var_name, + 16, 1, BufferType::kDefault)); + } + return si; +} + +PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( + const PrimFunc& original_primfunc) { + ScopeInfo si = UpdateFunctionScopeInfo(original_primfunc); + this->scope_stack.push(si); + Stmt new_body = this->VisitStmt(original_primfunc->body); + this->scope_stack.pop(); + return PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, + original_primfunc->attrs); +} + +Array PoolAllocationToOffsetConverter::AppendPoolParamsToArgs(const CallNode* op) { + Array new_args; + for (const auto& arg : op->args) { + new_args.push_back(VisitExpr(arg)); + } + for (const auto& pools_vars : this->scope_stack.top().pools_to_params) { + tir::Var pool_var = pools_vars.second; + new_args.push_back(pool_var); + } + return 
new_args; +} + +Array PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs( + const Array& args) { + Array ret; + for (const PrimExpr& arg : args) { + if (arg->IsInstance() && + allocate_buf_to_let_var_.find(Downcast(arg)) != allocate_buf_to_let_var_.end()) { + ret.push_back(allocate_buf_to_let_var_[Downcast(arg)]); + } else { + ret.push_back(arg); + } + } + return ret; +} + +PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) { + if (op->op.same_as(builtin::call_extern())) { + String func_name = Downcast(op->args[0])->value; + GlobalVar gv = function_global_vars_.at(func_name); + PrimFunc func = Downcast(module_->Lookup(gv)); + PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); + module_->Update(gv, prim_func); + Array new_args = AppendPoolParamsToArgs(op); + new_args = ReplaceAllocateArgsWithLetArgs(new_args); + return Call(op->dtype, builtin::call_extern(), new_args); + } else if (op->op->IsInstance()) { + PrimFunc func = Downcast(op->op); + PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); + Array new_args = AppendPoolParamsToArgs(op); + new_args = ReplaceAllocateArgsWithLetArgs(new_args); + return Call(op->dtype, prim_func, new_args); + } else { + return StmtExprMutator::VisitExpr_(op); + } +} + +Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) { + if (pool_allocations_.count(GetRef(op))) { + ScopeInfo scope_info = scope_stack.top(); + PoolAllocation pool_allocation = pool_allocations_[GetRef(op)]; + Var param = scope_info.pools_to_params[pool_allocation->pool_info]; + Buffer buffer_var = scope_info.buffer_map[param]; + ICHECK(pool_allocation->byte_offset < all_pools_sizes_[pool_allocation->pool_info]); + Load load_node = Load(op->dtype, buffer_var->data, pool_allocation->byte_offset, op->condition); + Var tir_var(op->buffer_var->name_hint + "_let", op->dtype); + allocate_buf_to_let_var_.Set(op->buffer_var, tir_var); + Stmt new_body = VisitStmt(op->body); + 
allocate_buf_to_let_var_.erase(op->buffer_var); + return LetStmt(tir_var, load_node, new_body); + } + return StmtExprMutator::VisitStmt_(op); +} + +Stmt PoolAllocationToOffsetConverter::VisitStmt_(const StoreNode* op) { + if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) { + return Store(allocate_buf_to_let_var_[op->buffer_var], VisitExpr(op->value), op->index, + VisitExpr(op->predicate)); + } + return StmtExprMutator::VisitStmt_(op); +} + +PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const LoadNode* op) { + if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) { + return Load(op->dtype, allocate_buf_to_let_var_[op->buffer_var], op->index, + VisitExpr(op->predicate)); + } + return StmtExprMutator::VisitExpr_(op); +} + +IRModule PoolAllocationToOffsetConverter::operator()() { + GlobalVar gv = function_global_vars_.at(::tvm::runtime::symbol::tvm_run_func_suffix); + PrimFunc main_func = Downcast(module_->Lookup(gv)); + ScopeInfo si = UpdateFunctionScopeInfo(main_func); + this->scope_stack.push(si); + Stmt main_func_body = this->VisitStmt(main_func->body); + this->scope_stack.pop(); + module_->Update(gv, PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, + main_func->attrs)); + return this->module_; +} + +namespace transform { + +tvm::transform::Pass ConvertPoolAllocationsToOffsets( + const Map& pool_allocations) { + auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { + return Downcast(PoolAllocationToOffsetConverter(m, pool_allocations)()); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.ConvertPoolAllocationsToOffsets", + {}); +} + +TVM_REGISTER_GLOBAL("tir.usmp.transform.ConvertPoolAllocationsToOffsets") + .set_body_typed(ConvertPoolAllocationsToOffsets); + +} // namespace transform + +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index 7a6a683770b0..69529c8c196f 
100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -144,6 +144,19 @@ Array CreateArrayBufferInfo(const Map& buffer_info return ret; } +Map AssignStmtPoolAllocations( + const Map& buffer_info_to_stmt, + const Map& buffer_info_to_pool_allocation) { + Map ret; + for (const auto& kv : buffer_info_to_pool_allocation) { + BufferInfo bi = kv.first; + Stmt stmt_ = buffer_info_to_stmt[bi]; + PoolAllocation pa = kv.second; + ret.Set(stmt_, pa); + } + return ret; +} + Integer CalculateExtentsSize(const AllocateNode* op) { size_t element_size_bytes = op->dtype.bytes(); size_t num_elements = 1; @@ -163,6 +176,8 @@ TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo") return (CreateArrayBufferInfo(buffer_info_map)); }); +TVM_REGISTER_GLOBAL("tir.usmp.AssignStmtPoolAllocations").set_body_typed(AssignStmtPoolAllocations); + } // namespace usmp } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py new file mode 100644 index 000000000000..bd3e6287b7c7 --- /dev/null +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -0,0 +1,519 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import sys + +import tvm +from tvm.script import tir as T +from tvm.tir import stmt_functor +from tvm.tir.usmp import utils as usmp_utils +from tvm.target import Target + + +def _get_primfuncs_from_module(module): + primfuncs = list() + for gv, primfunc in module.functions.items(): + primfuncs.append(primfunc) + return primfuncs + + +def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): + """helper to assing poolinfos to allocate nodes in a tir.PrimFunc""" + + def set_poolinfos(stmt): + if isinstance(stmt, tvm.tir.Allocate): + return tvm.tir.Allocate( + buffer_var=stmt.buffer_var, + dtype=stmt.dtype, + extents=stmt.extents, + condition=stmt.condition, + body=stmt.body, + annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos}, + ) + + return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos)) + + +def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): + """helper to assing poolinfos to allocate nodes in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) + return ret + + +# fmt: off +@tvm.script.ir_module +class LinearStructure: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, 
offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_7, ff_3, 0, True) + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, 
(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + for ax3_inner_7 in T.serial(0, 64): + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global") + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) 
+ ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_9 = T.allocate([301056], "int8", "global") + sid_8 = T.allocate([802816], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) +# fmt: on + + +# fmt: off +@tvm.script.ir_module +class LinearStructurePlanned: + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.handle, slow_memory_7_var: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8") + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16") + fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [802816], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + tensor_2_let: T.uint8 = T.load("uint8", fast_memory_6_buffer_var.data, 0) + for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init, T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + 
T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2, T.max(T.load("uint8", tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2), T.if_then_else(ax0_ax1_fused_4 * 2 + T.floordiv(rv0_rv1_fused_1, 3) < 112 and ax2_4 * 2 + T.floormod(rv0_rv1_fused_1, 3) < 112, T.load("uint8", placeholder_29.data, ax0_ax1_fused_4 * 14336 + T.floordiv(rv0_rv1_fused_1, 3) * 7168 + ax2_4 * 128 + T.floormod(rv0_rv1_fused_1, 3) * 64 + ax3_2), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5, ax2_5, ax3_3 in T.grid(56, 56, 64): + T.store(T_cast_7.data, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3, T.cast(T.load("uint8", tensor_2_let, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3), "int16"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "run_model", "runner_function": True}) + fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [802816], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_9_let: T.int8 = T.load("int8", slow_memory_1_buffer_var.data, 314646) + sid_8_let: T.int8 = T.load("int8", slow_memory_1_buffer_var.data, 0) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_var, slow_memory_1_var, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_var, slow_memory_1_var, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_var, slow_memory_1_var, dtype="int32")) + + @T.prim_func + 
def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle, slow_memory_3_var: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8") + placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16") + T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16") + fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [802816], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): + T.store(T_subtract_1.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1, T.cast(T.load("uint8", placeholder_4.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1), "int16") - T.load("int16", placeholder_5.data, 0), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle, slow_memory_5_var: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16") + placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16") + placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32") + T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8") + fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [802816], 
dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + PaddedInput_7_let: T.int16 = T.load("int16", slow_memory_5_buffer_var.data, 0) + for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3): + T.store(PaddedInput_7_let, i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7, T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, T.load("int16", placeholder_65.data, i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7_let: T.int32 = T.load("int32", fast_memory_4_buffer_var.data, 0) + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_7_let, ff_3, 0, True) + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + T.store(Conv2dOutput_7_let, ff_3, T.load("int32", Conv2dOutput_7_let, ff_3) + T.cast(T.load("int16", PaddedInput_7_let, T.floordiv(ax0_ax1_fused_ax2_fused_7, 112) * 1374 + ry_2 * 687 + T.floormod(ax0_ax1_fused_ax2_fused_7, 112) * 6 + rx_2 * 3 + rc_7), "int32") * T.cast(T.load("int16", placeholder_66.data, ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3), "int32"), True) + for ax3_inner_7 in T.serial(0, 64): + T.store(T_cast_21.data, ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7, T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_7_let, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) +# fmt: on + + +def test_linear(): + fast_memory_pool = usmp_utils.PoolInfo( + pool_name="fast_memory", + target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + size_hint_bytes=200704, + ) + slow_memory_pool = usmp_utils.PoolInfo( + pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS} + ) + tir_mod = LinearStructure + tir_mod = assign_poolinfos_to_allocates_in_irmodule( + tir_mod, [fast_memory_pool, slow_memory_pool] + ) + main_func = tir_mod["run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, 
tir_mod) + + fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") + buffer_info_arr = fcreate_array_bi(buffer_info_map) + fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") + buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr) + fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") + pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) + tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( + pool_allocations + )(tir_mod) + + tir_mod_with_offsets_ref = LinearStructurePlanned + # The TIR produced fails on roundtrip TVMScript testing. + # Therefore, indicates the TVMScript produced here and/or the parser + # is lacking functionality. Thus for these tests, uses a string + # version of the TVMScript for each function as a check instead. + for gv, func in tir_mod_with_offsets_ref.functions.items(): + assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str( + tir_mod_with_offsets[gv.name_hint].script() + ) + + +# fmt: off +@tvm.script.ir_module +class ResnetStructure: + @T.prim_func + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + # body + for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): + T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer 
* 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) + placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") + T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + # body + PaddedInput_1 = T.allocate([379456], "int16", "global") + for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): + T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): + Conv2dOutput_1 = T.allocate([64], "int32", "global") + for ff_1 in T.serial(0, 64): + T.store(Conv2dOutput_1, ff_1, 0, True) + for ry, rx, rc_1 in T.grid(3, 3, 64): + T.store(Conv2dOutput_1, ff_1, T.load("int32", Conv2dOutput_1, ff_1) + T.cast(T.load("int16", PaddedInput_1, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + for ax3_inner_2 in T.serial(0, 64): + T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1, 
ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) + placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + # body + PaddedInput_2 = T.allocate([360000], "int16", "global") + for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): + T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): + Conv2dOutput_2 = T.allocate([64], "int32", "global") + for ax3_outer_1 in T.serial(0, 4): + for ff_2 in T.serial(0, 64): + T.store(Conv2dOutput_2, ff_2, 0, True) + for rc_2 in T.serial(0, 64): + T.store(Conv2dOutput_2, ff_2, T.load("int32", Conv2dOutput_2, ff_2) + T.cast(T.load("int16", PaddedInput_2, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) + for ax3_inner_3 in T.serial(0, 64): + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, 
-8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") + placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") + T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + # body + PaddedInput_3 = T.allocate([360000], "int16", "global") + for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): + T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) + for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): + Conv2dOutput_3 = T.allocate([64], "int32", "global") + for ax3_outer_2 in T.serial(0, 4): + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_3, ff_3, 0, True) + for rc_3 in T.serial(0, 64): + T.store(Conv2dOutput_3, ff_3, T.load("int32", Conv2dOutput_3, ff_3) + T.cast(T.load("int16", PaddedInput_3, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) + for ax3_inner_4 in T.serial(0, 64): + T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", 
Conv2dOutput_3, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_2 = T.allocate([720000], "int8", "global") + sid_6 = T.allocate([5760000], "int8", "global") + sid_7 = T.allocate([720000], "int8", "global") + sid_8 = T.allocate([720000], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32")) + + @T.prim_func + def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) + placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") + T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + # body + PaddedInput = T.allocate([360000], "int16", "global") + for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): + T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): + Conv2dOutput = T.allocate([64], "int32", "global") + for ff in T.serial(0, 64): + T.store(Conv2dOutput, ff, 0, True) + for rc in T.serial(0, 64): + T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + for ax3_inner_1 in T.serial(0, 64): + T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) +# fmt: on + + +# fmt: off +@tvm.script.ir_module +class ResnetStructurePlanned: + @T.prim_func + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": 
"tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): + T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) + placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + PaddedInput_2_let: T.int16 = T.load("int16", global_workspace_4_buffer_var.data, 7200000) + for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 
75, 64): + T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): + Conv2dOutput_2_let: T.int32 = T.load("int32", global_workspace_4_buffer_var.data, 7920000) + for ax3_outer_1 in T.serial(0, 4): + for ff_2 in T.serial(0, 64): + T.store(Conv2dOutput_2_let, ff_2, 0, True) + for rc_2 in T.serial(0, 64): + T.store(Conv2dOutput_2_let, ff_2, T.load("int32", Conv2dOutput_2_let, ff_2) + T.cast(T.load("int16", PaddedInput_2_let, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) + for ax3_inner_3 in T.serial(0, 64): + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2_let, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) + placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") + T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, 
align=16) + # body + PaddedInput_1_let: T.int16 = T.load("int16", global_workspace_3_buffer_var.data, 0) + for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): + T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): + Conv2dOutput_1_let: T.int32 = T.load("int32", global_workspace_3_buffer_var.data, 7200000) + for ff_1 in T.serial(0, 64): + T.store(Conv2dOutput_1_let, ff_1, 0, True) + for ry, rx, rc_1 in T.grid(3, 3, 64): + T.store(Conv2dOutput_1_let, ff_1, T.load("int32", Conv2dOutput_1_let, ff_1) + T.cast(T.load("int16", PaddedInput_1_let, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + for ax3_inner_2 in T.serial(0, 64): + T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) + placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], 
dtype="int32") + T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + PaddedInput_let: T.int16 = T.load("int16", global_workspace_2_buffer_var.data, 6480000) + for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): + T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): + Conv2dOutput_let: T.int32 = T.load("int32", global_workspace_2_buffer_var.data, 7200000) + for ff in T.serial(0, 64): + T.store(Conv2dOutput_let, ff, 0, True) + for rc in T.serial(0, 64): + T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + for ax3_inner_1 in T.serial(0, 64): + T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "run_model", "runner_function": True}) + global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_2_let: T.int8 = T.load("int8", global_workspace_0_buffer_var.data, 5760000) + sid_6_let: T.int8 = T.load("int8", global_workspace_0_buffer_var.data, 0) + sid_7_let: T.int8 = T.load("int8", global_workspace_0_buffer_var.data, 6480000) + sid_8_let: T.int8 = 
T.load("int8", global_workspace_0_buffer_var.data, 6480000) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_var, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_var, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_var, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_var, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_var, dtype="int32")) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") + 
placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") + T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + PaddedInput_3_let: T.int16 = T.load("int16", global_workspace_5_buffer_var.data, 6480000) + for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): + T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) + for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): + Conv2dOutput_3_let: T.int32 = T.load("int32", global_workspace_5_buffer_var.data, 7200000) + for ax3_outer_2 in T.serial(0, 4): + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_3_let, ff_3, 0, True) + for rc_3 in T.serial(0, 64): + T.store(Conv2dOutput_3_let, ff_3, T.load("int32", Conv2dOutput_3_let, ff_3) + T.cast(T.load("int16", PaddedInput_3_let, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) + for ax3_inner_4 in T.serial(0, 64): + T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3_let, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) +# fmt: on + + +def test_fanout(): + global_workspace_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={Target("c"): 
usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + tir_mod = ResnetStructure + tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) + main_func = tir_mod["run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + + fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") + buffer_info_arr = fcreate_array_bi(buffer_info_map) + fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") + buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr) + fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") + pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) + tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( + pool_allocations + )(tir_mod) + + tir_mod_with_offsets_ref = ResnetStructurePlanned + # The TIR produced fails on roundtrip TVMScript testing. + # Therefore, indicates the TVMScript produced here and/or the parser + # is lacking functionality. Thus for these tests, uses a string + # version of the TVMScript for each function as a check instead. + for gv, func in tir_mod_with_offsets_ref.functions.items(): + assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str( + tir_mod_with_offsets[gv.name_hint].script() + ) + + +if __name__ == "__main__": + pytest.main([__file__] + sys.argv[1:]) From 2b318cf30482300cc77d70b2d28a11706aa9274f Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Mon, 22 Nov 2021 18:26:58 +0000 Subject: [PATCH 2/7] [TIR][USMP] adding the pass to convert to pool offsets * Adding a toggle to produce TIR that is TVMScript printable for unit testing * Fixing the unit tests * Ensure deterministic pool variable ordering. 
Change-Id: I317675df03327b0ebbf4ca074255384e63f07cd6 --- include/tvm/tir/usmp/utils.h | 52 +++++ python/tvm/tir/usmp/transform/transform.py | 9 +- src/tir/ir/stmt.cc | 7 +- .../convert_pool_allocations_to_offsets.cc | 144 +++++++++--- src/tir/usmp/utils.cc | 24 ++ ...orm_convert_pool_allocations_to_offsets.py | 219 +++++++++--------- 6 files changed, 303 insertions(+), 152 deletions(-) diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h index 145c61dd518b..a7a245c06378 100644 --- a/include/tvm/tir/usmp/utils.h +++ b/include/tvm/tir/usmp/utils.h @@ -225,6 +225,44 @@ class PoolAllocation : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolAllocation, ObjectRef, PoolAllocationNode); }; +/*! + * \brief This object contains information post-allocation for PoolInfo objects + */ +struct AllocatedPoolInfoNode : public Object { + /*! \brief The assigned PoolInfo object */ + PoolInfo pool_info; + /*! \brief The allocated size into this pool */ + Integer allocated_size; + /*! 
\brief An optional associated pool Var*/ + Optional pool_var; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pool_info", &pool_info); + v->Visit("allocated_size", &allocated_size); + v->Visit("pool_var", &pool_var); + } + + bool SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal) const { + return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) && + equal(pool_var, other->pool_var); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(pool_info); + hash_reduce(allocated_size); + hash_reduce(pool_var); + } + + static constexpr const char* _type_key = "tir.usmp.AllocatedPoolInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(AllocatedPoolInfoNode, Object); +}; + +class AllocatedPoolInfo : public ObjectRef { + public: + TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var = Var()); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode); +}; + /*! * \brief Convert the IR-bound BufferInfo map to an array of BufferInfo * @@ -248,6 +286,20 @@ Integer CalculateExtentsSize(const AllocateNode* op); } // namespace usmp } // namespace tir + +namespace attr { +/*! + * \brief This is a BaseFunc attribute to indicate which input var represent + * a PoolInfo Object in the form of a Map. + */ +static constexpr const char* kPoolArgs = "pool_args"; +/*! + * \brief This is a BaseFunc attribute to indicate which input var represent + * a PoolInfo Object in the form of a Map. 
+ */ +static constexpr const char* kPoolInfoIRModuleAttr = "pool_infos"; +} // namespace attr + } // namespace tvm #endif // TVM_TIR_USMP_UTILS_H_ diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py index 5739d5f78067..4976215c21a6 100644 --- a/python/tvm/tir/usmp/transform/transform.py +++ b/python/tvm/tir/usmp/transform/transform.py @@ -24,17 +24,22 @@ from ..utils import PoolAllocation -def convert_pool_allocations_to_offsets(pool_allocations: Dict[Stmt, PoolAllocation]): +def convert_pool_allocations_to_offsets( + pool_allocations: Dict[Stmt, PoolAllocation], emit_tvmscript_printable: bool = False +): """Convert pool allocations to Load nodes with offsets from pools. Parameters ---------- pool_allocations : Dict[Stmt, PoolAllocation] Allocate or AllocateConst node to pool allocation mapping + emit_tvmscript_printable : bool + A toggle to emit TVMScript printable IRModule for unit tests + removing all attributes that should be attached for integration Returns ------- ret: tvm.transform.Pass The registered pass that converts the allocations to offsets. 
""" - return _ffi_api.ConvertPoolAllocationsToOffsets(pool_allocations) + return _ffi_api.ConvertPoolAllocationsToOffsets(pool_allocations, emit_tvmscript_printable) diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 0d42c20c2822..46c406ba902f 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -35,7 +35,12 @@ namespace tir { LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { ICHECK(value.defined()); ICHECK(body.defined()); - ICHECK_EQ(value.dtype(), var.dtype()); + auto vdtype = value.dtype(); + if (var->type_annotation.as()) { + ICHECK(vdtype.is_handle()); + } else { + ICHECK_EQ(value.dtype(), var.dtype()); + } ObjectPtr node = make_object(); node->var = std::move(var); diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index fb1ec91ec149..1993f8685c15 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -39,8 +39,9 @@ namespace usmp { class PoolAllocationToOffsetConverter : public StmtExprMutator { public: explicit PoolAllocationToOffsetConverter(const IRModule& module, - const Map& pool_allocations) - : pool_allocations_(pool_allocations) { + const Map& pool_allocations, + bool emit_tvmscript_printable = false) + : pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) { module_ = module->ShallowCopy(); for (const auto& gv_func : module_->functions) { function_global_vars_.Set(gv_func.first->name_hint, gv_func.first); @@ -51,7 +52,6 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { Allocate allocate_node = Downcast(kv.first); PoolAllocation pool_allocation = kv.second; PoolInfo pool_info = pool_allocation->pool_info; - pool_ordering_.insert(pool_info); int byte_pool_offset = pool_allocation->byte_offset->value; int required_pool_size_for_allocation = byte_pool_offset + 
CalculateExtentsSize(allocate_node.operator->()); @@ -64,12 +64,26 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { } } } + + for (const auto& kv : all_pools_sizes_) { + PoolInfo pi = kv.first; + int allocated_size = kv.second; + allocated_pool_ordering_.push_back(AllocatedPoolInfo(pi, allocated_size)); + } + std::sort(allocated_pool_ordering_.begin(), allocated_pool_ordering_.end(), + [](const AllocatedPoolInfo& lhs, const AllocatedPoolInfo& rhs) { + if (lhs->pool_info->pool_name < rhs->pool_info->pool_name) { + return true; + } + return false; + }); } IRModule operator()(); private: PrimExpr VisitExpr_(const CallNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; + // PrimExpr VisitExpr_(const VarNode* op) override; PrimExpr VisitExpr_(const LoadNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; @@ -79,6 +93,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { struct ScopeInfo { Array params; Map pools_to_params; + Array allocated_pool_params; Map buffer_map; }; @@ -101,7 +116,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { /*! \brief This is a helper to append the pool args to * the callsite of the function. */ - Array AppendPoolParamsToArgs(const CallNode* op); + Array AppendPoolParamsToArgs(const Array& args); /*! \brief Some arguments that used to be Allocate nodes * should be replaced by Let nodes in the pass that loads * the space from a pool variable. @@ -117,7 +132,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { /*! \brief The input allocate node to PoolAllocation map */ Map pool_allocations_; /*! \brief The set of ordered pools to ensure an unique order of args for functions */ - std::set pool_ordering_; + std::vector allocated_pool_ordering_; /*! \brief The storage of calculated pool size at init */ std::unordered_map all_pools_sizes_; /*! 
\brief The AoT codegen uses extern_calls due to some functions not being exposed in the TIR @@ -130,6 +145,10 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { Map allocate_buf_to_let_var_; /*! \brief A counter to give references to pools a reproducible unique set of names */ int pool_var_count_ = 0; + /*! \brief This toggles to remove non tvmscript printable items for IRModule for unit tests */ + bool emit_tvmscript_printable_ = false; + /*! \brief A counter to give references to pools a reproducible unique set of names */ + std::unordered_set visited_primfuncs; }; PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo( @@ -138,14 +157,22 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda si.params = original_func->params; si.buffer_map = original_func->buffer_map; Map ret; - for (const PoolInfo& pool_info : pool_ordering_) { + for (const AllocatedPoolInfo& allocated_pool_info : allocated_pool_ordering_) { + PoolInfo pool_info = allocated_pool_info->pool_info; String pool_ref_name = pool_info->pool_name + "_" + std::to_string(pool_var_count_++); String var_name = pool_ref_name + "_var"; DataType elem_dtype = DataType::UInt(8); Var buffer_var(var_name, PointerType(PrimType(elem_dtype), "global")); - Var pool_var(var_name, DataType::Handle()); + Var pool_var; + if (!emit_tvmscript_printable_) { + pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global")); + } else { + pool_var = Var(var_name, DataType::Handle(8)); + } si.params.push_back(pool_var); si.pools_to_params.Set(pool_info, pool_var); + si.allocated_pool_params.push_back(AllocatedPoolInfo( + allocated_pool_info->pool_info, allocated_pool_info->allocated_size, pool_var)); int pool_size = all_pools_sizes_[pool_info]; String buffer_var_name = pool_ref_name + "_buffer_var"; @@ -157,22 +184,40 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda PrimFunc 
PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( const PrimFunc& original_primfunc) { - ScopeInfo si = UpdateFunctionScopeInfo(original_primfunc); - this->scope_stack.push(si); - Stmt new_body = this->VisitStmt(original_primfunc->body); - this->scope_stack.pop(); - return PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, - original_primfunc->attrs); + // Only create the new function if it was not modified with pool params + if (visited_primfuncs.find(original_primfunc) == visited_primfuncs.end()) { + ScopeInfo si = UpdateFunctionScopeInfo(original_primfunc); + this->scope_stack.push(si); + Stmt new_body = this->VisitStmt(original_primfunc->body); + this->scope_stack.pop(); + DictAttrs original_attrs = original_primfunc->attrs; + // We dont need attrs of PrimFunc that might include non printable attrs such as target + // for unit tests where emit_tvmscript_printable_ is to be used. + if (emit_tvmscript_printable_) { + original_attrs = DictAttrs(); + } + PrimFunc ret = + PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs); + if (!emit_tvmscript_printable_) { + return WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params); + } + visited_primfuncs.insert(ret); + return ret; + } + return original_primfunc; } -Array PoolAllocationToOffsetConverter::AppendPoolParamsToArgs(const CallNode* op) { +Array PoolAllocationToOffsetConverter::AppendPoolParamsToArgs( + const Array& args) { Array new_args; - for (const auto& arg : op->args) { + for (const auto& arg : args) { new_args.push_back(VisitExpr(arg)); } - for (const auto& pools_vars : this->scope_stack.top().pools_to_params) { + ScopeInfo top_scope = this->scope_stack.top(); + for (const auto& pools_vars : top_scope.pools_to_params) { tir::Var pool_var = pools_vars.second; - new_args.push_back(pool_var); + Buffer buffer_var = top_scope.buffer_map[pool_var]; + new_args.push_back(buffer_var->data); } return new_args; } @@ -192,24 +237,30 @@ 
Array PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs( } PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) { - if (op->op.same_as(builtin::call_extern())) { + if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) { String func_name = Downcast(op->args[0])->value; - GlobalVar gv = function_global_vars_.at(func_name); - PrimFunc func = Downcast(module_->Lookup(gv)); - PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); - module_->Update(gv, prim_func); - Array new_args = AppendPoolParamsToArgs(op); - new_args = ReplaceAllocateArgsWithLetArgs(new_args); - return Call(op->dtype, builtin::call_extern(), new_args); - } else if (op->op->IsInstance()) { + Array new_args; + if (function_global_vars_.find(func_name) != function_global_vars_.end()) { + GlobalVar gv = function_global_vars_.at(func_name); + PrimFunc func = Downcast(module_->Lookup(gv)); + PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); + module_->Update(gv, prim_func); + new_args = AppendPoolParamsToArgs(op->args); + new_args = ReplaceAllocateArgsWithLetArgs(new_args); + } else { + new_args = ReplaceAllocateArgsWithLetArgs(op->args); + } + return Call(op->dtype, op->op, new_args); + } + if (op->op->IsInstance()) { PrimFunc func = Downcast(op->op); PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); - Array new_args = AppendPoolParamsToArgs(op); + Array new_args = AppendPoolParamsToArgs(op->args); + new_args = AppendPoolParamsToArgs(new_args); new_args = ReplaceAllocateArgsWithLetArgs(new_args); return Call(op->dtype, prim_func, new_args); - } else { - return StmtExprMutator::VisitExpr_(op); } + return StmtExprMutator::VisitExpr_(op); } Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) { @@ -219,12 +270,19 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) { Var param = scope_info.pools_to_params[pool_allocation->pool_info]; Buffer buffer_var = 
scope_info.buffer_map[param]; ICHECK(pool_allocation->byte_offset < all_pools_sizes_[pool_allocation->pool_info]); - Load load_node = Load(op->dtype, buffer_var->data, pool_allocation->byte_offset, op->condition); - Var tir_var(op->buffer_var->name_hint + "_let", op->dtype); + Load load_node = + Load(DataType::UInt(8), buffer_var->data, pool_allocation->byte_offset, op->condition); + Call address_of_load = Call(DataType::Handle(8), builtin::address_of(), {load_node}); + Var tir_var; + if (!emit_tvmscript_printable_) { + tir_var = Var(op->buffer_var->name_hint + "_let", op->buffer_var->type_annotation); + } else { + tir_var = Var(op->buffer_var->name_hint + "_let", DataType::Handle(8)); + } allocate_buf_to_let_var_.Set(op->buffer_var, tir_var); Stmt new_body = VisitStmt(op->body); allocate_buf_to_let_var_.erase(op->buffer_var); - return LetStmt(tir_var, load_node, new_body); + return LetStmt(tir_var, address_of_load, new_body); } return StmtExprMutator::VisitStmt_(op); } @@ -252,17 +310,31 @@ IRModule PoolAllocationToOffsetConverter::operator()() { this->scope_stack.push(si); Stmt main_func_body = this->VisitStmt(main_func->body); this->scope_stack.pop(); - module_->Update(gv, PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, - main_func->attrs)); + // We dont need attrs of PrimFunc that might include non printable attrs such as target + // for unit tests where emit_tvmscript_printable_ is to be used. 
+ if (!emit_tvmscript_printable_) { + main_func = + PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs); + main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params); + } else { + main_func = + PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs()); + } + module_->Update(gv, main_func); + if (!emit_tvmscript_printable_) { + return WithAttr(this->module_, tvm::attr::kPoolArgs, si.allocated_pool_params); + } return this->module_; } namespace transform { tvm::transform::Pass ConvertPoolAllocationsToOffsets( - const Map& pool_allocations) { + const Map& pool_allocations, + Bool emit_tvmscript_printable = Bool(false)) { auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { - return Downcast(PoolAllocationToOffsetConverter(m, pool_allocations)()); + return Downcast(PoolAllocationToOffsetConverter( + m, pool_allocations, emit_tvmscript_printable->value != 0)()); }; return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.ConvertPoolAllocationsToOffsets", {}); diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index 69529c8c196f..14b3d26641a3 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -135,6 +135,30 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); +AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var) { + auto allocated_poolinfo_node = make_object(); + allocated_poolinfo_node->pool_info = pool_info; + allocated_poolinfo_node->allocated_size = allocated_size; + if (pool_var.defined()) { + allocated_poolinfo_node->pool_var = pool_var; + } + data_ = std::move(allocated_poolinfo_node); +} + +TVM_REGISTER_NODE_TYPE(AllocatedPoolInfoNode); +TVM_REGISTER_GLOBAL("tir.usmp.AllocatedPoolInfo") + .set_body_typed([](PoolInfo pool_info, Integer allocated_size) { + return AllocatedPoolInfo(pool_info, allocated_size); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const 
ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "AllocatedPoolInfoNode(\n" + << "pool_info=" << node->pool_info << ",\n allocated_size=" << node->allocated_size + << ")"; + }); + Array CreateArrayBufferInfo(const Map& buffer_info_map) { Array ret; for (const auto& kv : buffer_info_map) { diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index bd3e6287b7c7..7174a796a70e 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -57,6 +57,15 @@ def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): return ret +def _assign_targets_to_primfuncs_irmodule(mod, target): + """helper to assign target for PrimFunc in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = basefunc.with_attr("target", target) + return ret + + # fmt: off @tvm.script.ir_module class LinearStructure: @@ -130,87 +139,81 @@ def run_model(input: T.handle, output: T.handle) -> None: # fmt: off @tvm.script.ir_module class LinearStructurePlanned: + @T.prim_func + def run_model(input: T.handle, output: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T.handle) -> None: + fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_9_let: T.handle = T.address_of(T.load("uint8", slow_memory_1_buffer_var.data, 1117472), dtype="handle") + sid_8_let: T.handle = T.address_of(T.load("uint8", 
slow_memory_1_buffer_var.data, 0), dtype="handle") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) + @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.handle, slow_memory_7_var: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [802816], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - tensor_2_let: T.uint8 = T.load("uint8", fast_memory_6_buffer_var.data, 0) + tensor_2_let: T.handle = T.address_of(T.load("uint8", fast_memory_6_buffer_var.data, 0), dtype="handle") for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56): for ax3_init in T.serial(0, 64): T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init, T.uint8(0), True) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2, 
T.max(T.load("uint8", tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2), T.if_then_else(ax0_ax1_fused_4 * 2 + T.floordiv(rv0_rv1_fused_1, 3) < 112 and ax2_4 * 2 + T.floormod(rv0_rv1_fused_1, 3) < 112, T.load("uint8", placeholder_29.data, ax0_ax1_fused_4 * 14336 + T.floordiv(rv0_rv1_fused_1, 3) * 7168 + ax2_4 * 128 + T.floormod(rv0_rv1_fused_1, 3) * 64 + ax3_2), T.uint8(0), dtype="uint8")), True) + T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2, T.max(T.load("uint8", tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2), T.if_then_else(ax0_ax1_fused_4 * 2 + rv0_rv1_fused_1 // 3 < 112 and ax2_4 * 2 + rv0_rv1_fused_1 % 3 < 112, T.load("uint8", placeholder_29.data, ax0_ax1_fused_4 * 14336 + rv0_rv1_fused_1 // 3 * 7168 + ax2_4 * 128 + rv0_rv1_fused_1 % 3 * 64 + ax3_2), T.uint8(0), dtype="uint8")), True) for ax0_ax1_fused_5, ax2_5, ax3_3 in T.grid(56, 56, 64): T.store(T_cast_7.data, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3, T.cast(T.load("uint8", tensor_2_let, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3), "int16"), True) - @T.prim_func - def run_model(input: T.handle, output: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "run_model", "runner_function": True}) - fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [802816], dtype="uint8", strides=[1], elem_offset=1, align=16) - # body - T.attr("default", "device_id", 0) - T.attr("default", "device_type", 1) - sid_9_let: T.int8 = T.load("int8", slow_memory_1_buffer_var.data, 314646) - sid_8_let: T.int8 = T.load("int8", slow_memory_1_buffer_var.data, 0) - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_var, slow_memory_1_var, dtype="int32")) - 
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_var, slow_memory_1_var, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_var, slow_memory_1_var, dtype="int32")) - @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle, slow_memory_3_var: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16") fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [802816], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): T.store(T_subtract_1.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1, T.cast(T.load("uint8", placeholder_4.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1), "int16") - T.load("int16", placeholder_5.data, 0), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle, slow_memory_5_var: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) placeholder_65 
= T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16") placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32") T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8") fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [802816], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_7_let: T.int16 = T.load("int16", slow_memory_5_buffer_var.data, 0) + PaddedInput_7_let: T.handle = T.address_of(T.load("uint8", slow_memory_5_buffer_var.data, 802816), dtype="handle") for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3): T.store(PaddedInput_7_let, i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7, T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, T.load("int16", placeholder_65.data, i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350), T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): - Conv2dOutput_7_let: T.int32 = T.load("int32", fast_memory_4_buffer_var.data, 0) + Conv2dOutput_7_let: T.handle = T.address_of(T.load("uint8", fast_memory_4_buffer_var.data, 0), dtype="handle") for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_7_let, ff_3, 0, True) for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7_let, ff_3, T.load("int32", Conv2dOutput_7_let, ff_3) + T.cast(T.load("int16", PaddedInput_7_let, T.floordiv(ax0_ax1_fused_ax2_fused_7, 112) * 1374 + ry_2 * 687 + T.floormod(ax0_ax1_fused_ax2_fused_7, 112) * 6 + rx_2 * 3 + rc_7), "int32") * T.cast(T.load("int16", placeholder_66.data, ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3), "int32"), True) + T.store(Conv2dOutput_7_let, ff_3, T.load("int32", 
Conv2dOutput_7_let, ff_3) + T.cast(T.load("int16", PaddedInput_7_let, ax0_ax1_fused_ax2_fused_7 // 112 * 1374 + ry_2 * 687 + ax0_ax1_fused_ax2_fused_7 % 112 * 6 + rx_2 * 3 + rc_7), "int32") * T.cast(T.load("int16", placeholder_66.data, ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3), "int32"), True) for ax3_inner_7 in T.serial(0, 64): T.store(T_cast_21.data, ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7, T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_7_let, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) # fmt: on def test_linear(): + target = Target("c") fast_memory_pool = usmp_utils.PoolInfo( pool_name="fast_memory", - target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, size_hint_bytes=200704, ) slow_memory_pool = usmp_utils.PoolInfo( - pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS} + pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} ) tir_mod = LinearStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) tir_mod = assign_poolinfos_to_allocates_in_irmodule( tir_mod, [fast_memory_pool, slow_memory_pool] ) @@ -224,7 +227,7 @@ def test_linear(): fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( - pool_allocations + pool_allocations, emit_tvmscript_printable=True )(tir_mod) tir_mod_with_offsets_ref = LinearStructurePlanned @@ -361,78 +364,19 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place # fmt: off @tvm.script.ir_module class ResnetStructurePlanned: - @T.prim_func - def 
tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) - placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") - placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") - global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) - # body - for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): - T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) - - @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) - placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") - placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") - placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") - T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") - global_workspace_4_buffer_var = 
T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) - # body - PaddedInput_2_let: T.int16 = T.load("int16", global_workspace_4_buffer_var.data, 7200000) - for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): - T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) - for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): - Conv2dOutput_2_let: T.int32 = T.load("int32", global_workspace_4_buffer_var.data, 7920000) - for ax3_outer_1 in T.serial(0, 4): - for ff_2 in T.serial(0, 64): - T.store(Conv2dOutput_2_let, ff_2, 0, True) - for rc_2 in T.serial(0, 64): - T.store(Conv2dOutput_2_let, ff_2, T.load("int32", Conv2dOutput_2_let, ff_2) + T.cast(T.load("int16", PaddedInput_2_let, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) - for ax3_inner_3 in T.serial(0, 64): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2_let, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) - - @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) - placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") - placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") - placeholder_15 = 
T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") - T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") - global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) - # body - PaddedInput_1_let: T.int16 = T.load("int16", global_workspace_3_buffer_var.data, 0) - for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): - T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) - for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): - Conv2dOutput_1_let: T.int32 = T.load("int32", global_workspace_3_buffer_var.data, 7200000) - for ff_1 in T.serial(0, 64): - T.store(Conv2dOutput_1_let, ff_1, 0, True) - for ry, rx, rc_1 in T.grid(3, 3, 64): - T.store(Conv2dOutput_1_let, ff_1, T.load("int32", Conv2dOutput_1_let, ff_1) + T.cast(T.load("int16", PaddedInput_1_let, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) - for ax3_inner_2 in T.serial(0, 64): - T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) - @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": 
"tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") - global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_let: T.int16 = T.load("int16", global_workspace_2_buffer_var.data, 6480000) + PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 0), dtype="handle") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): - Conv2dOutput_let: T.int32 = T.load("int32", global_workspace_2_buffer_var.data, 7200000) + Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 1478912), dtype="handle") for ff in T.serial(0, 64): T.store(Conv2dOutput_let, ff, 0, True) for rc in T.serial(0, 64): @@ -442,38 +386,34 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @T.prim_func def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "run_model", "runner_function": True}) - global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) # 
body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_2_let: T.int8 = T.load("int8", global_workspace_0_buffer_var.data, 5760000) - sid_6_let: T.int8 = T.load("int8", global_workspace_0_buffer_var.data, 0) - sid_7_let: T.int8 = T.load("int8", global_workspace_0_buffer_var.data, 6480000) - sid_8_let: T.int8 = T.load("int8", global_workspace_0_buffer_var.data, 6480000) - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_var, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_var, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_var, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_var, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_var, dtype="int32")) + sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle") + sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle") + sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 758912), dtype="handle") + 
sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 758912), dtype="handle") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_buffer_var.data, dtype="int32")) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], 
dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") - global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_3_let: T.int16 = T.load("int16", global_workspace_5_buffer_var.data, 6480000) + PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 5760000), dtype="handle") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): - Conv2dOutput_3_let: T.int32 = T.load("int32", global_workspace_5_buffer_var.data, 7200000) + Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle") for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_3_let, ff_3, 0, True) @@ -481,15 +421,68 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T.store(Conv2dOutput_3_let, ff_3, T.load("int32", Conv2dOutput_3_let, ff_3) + T.cast(T.load("int16", PaddedInput_3_let, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) for ax3_inner_4 in T.serial(0, 64): T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", 
Conv2dOutput_3_let, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle) -> None: + placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 5760000), dtype="handle") + for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): + T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): + Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 6480000), dtype="handle") + for ax3_outer_1 in T.serial(0, 4): + for ff_2 in T.serial(0, 64): + T.store(Conv2dOutput_2_let, ff_2, 0, True) + for rc_2 in T.serial(0, 64): + T.store(Conv2dOutput_2_let, ff_2, T.load("int32", Conv2dOutput_2_let, ff_2) + T.cast(T.load("int16", PaddedInput_2_let, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), 
"int32"), True) + for ax3_inner_3 in T.serial(0, 64): + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2_let, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + + @T.prim_func + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None: + placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): + T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None: + placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") + T_cast_5 = 
T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + PaddedInput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 0), dtype="handle") + for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): + T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): + Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 1478912), dtype="handle") + for ff_1 in T.serial(0, 64): + T.store(Conv2dOutput_1_let, ff_1, 0, True) + for ry, rx, rc_1 in T.grid(3, 3, 64): + T.store(Conv2dOutput_1_let, ff_1, T.load("int32", Conv2dOutput_1_let, ff_1) + T.cast(T.load("int16", PaddedInput_1_let, ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + for ax3_inner_2 in T.serial(0, 64): + T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) # fmt: on def test_fanout(): + target = Target("c") global_workspace_pool = usmp_utils.PoolInfo( pool_name="global_workspace", - target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, ) tir_mod = ResnetStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) tir_mod = 
assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) main_func = tir_mod["run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) @@ -501,7 +494,7 @@ def test_fanout(): fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( - pool_allocations + pool_allocations, emit_tvmscript_printable=True )(tir_mod) tir_mod_with_offsets_ref = ResnetStructurePlanned From 4b34dd7aae3ff6e6359099c2548597547668bd74 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 1 Dec 2021 16:20:49 +0000 Subject: [PATCH 3/7] [TIR][USMP] adding the pass to convert to pool offsets Fixing the references after changes in the memory planning algorithm. Change-Id: Id7c22356fd5de43d10a2b4fc70e978af2c6d599d --- ...orm_convert_pool_allocations_to_offsets.py | 103 +++++++++--------- 1 file changed, 53 insertions(+), 50 deletions(-) diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 7174a796a70e..bfa00f097202 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -202,7 +202,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde # fmt: on -def test_linear(): +def test_mobilenet_subgraph(): target = Target("c") fast_memory_pool = usmp_utils.PoolInfo( pool_name="fast_memory", @@ -231,6 +231,7 @@ def test_linear(): )(tir_mod) tir_mod_with_offsets_ref = LinearStructurePlanned + tir_mod_with_offsets_ref = tvm.script.from_source(tir_mod_with_offsets_ref.script(show_meta=False)) # The TIR produced fails on roundtrip TVMScript testing. 
# Therefore, indicates the TVMScript produced here and/or the parser # is lacking functionality. Thus for these tests, uses a string @@ -365,40 +366,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @tvm.script.ir_module class ResnetStructurePlanned: @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None: - placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") - placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") - placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") - T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") - global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) - # body - PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 0), dtype="handle") - for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): - T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) - for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): - Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 1478912), dtype="handle") - for ff in T.serial(0, 64): - T.store(Conv2dOutput_let, ff, 0, True) - for rc in T.serial(0, 64): - T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) - for ax3_inner_1 in T.serial(0, 64): - T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", 
placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) - - @T.prim_func - def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None: - global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None: + placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - T.attr("default", "device_id", 0) - T.attr("default", "device_type", 1) - sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle") - sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle") - sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 758912), dtype="handle") - sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 758912), dtype="handle") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", 
sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_buffer_var.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_buffer_var.data, dtype="int32")) + for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): + T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None: @@ -407,13 +382,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") - global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) + 
global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 5760000), dtype="handle") + PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): - Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle") + Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 7200000), dtype="handle") for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_3_let, ff_3, 0, True) @@ -428,13 +403,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") - global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 5760000), dtype="handle") + PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7200000), dtype="handle") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", 
placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): - Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 6480000), dtype="handle") + Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7920000), dtype="handle") for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): T.store(Conv2dOutput_2_let, ff_2, 0, True) @@ -444,14 +419,24 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2_let, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) @T.prim_func - def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None: - placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") - placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") - global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None: + placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") + T_cast_3 = 
T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): - T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7200000), dtype="handle") + for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): + T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): + Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7920000), dtype="handle") + for ff in T.serial(0, 64): + T.store(Conv2dOutput_let, ff, 0, True) + for rc in T.serial(0, 64): + T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + for ax3_inner_1 in T.serial(0, 64): + T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, 
global_workspace_3_var: T.handle) -> None: @@ -459,23 +444,40 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") - global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body PaddedInput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 0), dtype="handle") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): - Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 1478912), dtype="handle") + Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 7200000), dtype="handle") for ff_1 in T.serial(0, 64): T.store(Conv2dOutput_1_let, ff_1, 0, True) for ry, rx, rc_1 in T.grid(3, 3, 64): T.store(Conv2dOutput_1_let, ff_1, T.load("int32", Conv2dOutput_1_let, ff_1) + T.cast(T.load("int16", PaddedInput_1_let, ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) for ax3_inner_2 in T.serial(0, 64): T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, 
T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None: + global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 5760000), dtype="handle") + sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle") + sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle") + sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, 
global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_buffer_var.data, dtype="int32")) + __tvm_meta__ = None # fmt: on -def test_fanout(): +def test_resnet_subgraph(): target = Target("c") global_workspace_pool = usmp_utils.PoolInfo( pool_name="global_workspace", @@ -498,6 +500,7 @@ def test_fanout(): )(tir_mod) tir_mod_with_offsets_ref = ResnetStructurePlanned + # The TIR produced fails on roundtrip TVMScript testing. # Therefore, indicates the TVMScript produced here and/or the parser # is lacking functionality. Thus for these tests, uses a string From 673b2980c5fabda9bcbff5a4f8cf0f23b2b77b51 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Thu, 2 Dec 2021 17:39:00 +0000 Subject: [PATCH 4/7] [TIR][USMP] adding the pass to convert to pool offsets * fixing the lint Change-Id: I7ff920b92d14a9919c930a4b35a2169c77a57dd1 --- ..._tir_usmp_transform_convert_pool_allocations_to_offsets.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index bfa00f097202..a9d17cb9b6f6 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -231,7 +231,9 @@ def test_mobilenet_subgraph(): )(tir_mod) tir_mod_with_offsets_ref = LinearStructurePlanned - tir_mod_with_offsets_ref = tvm.script.from_source(tir_mod_with_offsets_ref.script(show_meta=False)) + tir_mod_with_offsets_ref = tvm.script.from_source( + tir_mod_with_offsets_ref.script(show_meta=False) + ) # The TIR produced fails on roundtrip 
TVMScript testing. # Therefore, indicates the TVMScript produced here and/or the parser # is lacking functionality. Thus for these tests, uses a string From 2ff702a89446a123523ad6afc9d049cdb24d6d5e Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Mon, 6 Dec 2021 16:48:33 +0000 Subject: [PATCH 5/7] [TIR][USMP] adding the pass to convert to pool offsets * removing unnecessary definitions * remove global var map * adding explanation for let bindings to pointer type Change-Id: I31bd1a9f3057ee7f06252263565b0f75c51e6d13 --- include/tvm/tir/usmp/utils.h | 6 +---- src/tir/ir/stmt.cc | 2 ++ .../convert_pool_allocations_to_offsets.cc | 27 ++++++++++--------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h index a7a245c06378..30c8f2ddea49 100644 --- a/include/tvm/tir/usmp/utils.h +++ b/include/tvm/tir/usmp/utils.h @@ -293,11 +293,7 @@ namespace attr { * a PoolInfo Object in the form of a Map. */ static constexpr const char* kPoolArgs = "pool_args"; -/*! - * \brief This is a BaseFunc attribute to indicate which input var represent - * a PoolInfo Object in the form of a Map. - */ -static constexpr const char* kPoolInfoIRModuleAttr = "pool_infos"; + } // namespace attr } // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 46c406ba902f..078561c447ad 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -36,6 +36,8 @@ LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { ICHECK(value.defined()); ICHECK(body.defined()); auto vdtype = value.dtype(); + // It is still valid to bind a pointer type + // var to a value that is of type handle.
if (var->type_annotation.as()) { ICHECK(vdtype.is_handle()); } else { diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index 1993f8685c15..eac83d9b8657 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -36,16 +36,21 @@ namespace tvm { namespace tir { namespace usmp { +/*! + * \brief The StmtExpr mutator class to replace allocate nodes + * with offsets within memory pools + * + * This mutator class with add Pool variables recursively to every PrimFunc + * starting from the main PrimFunc. For all allocate nodes, that have been + * memory planned, will be mutated into an offset using a Let binding. + */ class PoolAllocationToOffsetConverter : public StmtExprMutator { public: - explicit PoolAllocationToOffsetConverter(const IRModule& module, - const Map& pool_allocations, - bool emit_tvmscript_printable = false) + PoolAllocationToOffsetConverter(const IRModule& module, + const Map& pool_allocations, + bool emit_tvmscript_printable = false) : pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) { module_ = module->ShallowCopy(); - for (const auto& gv_func : module_->functions) { - function_global_vars_.Set(gv_func.first->name_hint, gv_func.first); - } for (const auto& kv : pool_allocations) { // TODO(@manupa-arm): add AllocateConstNode when it is available ICHECK(kv.first->IsInstance()); @@ -135,10 +140,6 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { std::vector allocated_pool_ordering_; /*! \brief The storage of calculated pool size at init */ std::unordered_map all_pools_sizes_; - /*! \brief The AoT codegen uses extern_calls due to some functions not being exposed in the TIR - * IRModule This maps maintains the map of which to each function - */ - Map function_global_vars_; /*! 
\brief After mutation, each allocate buffer is replaced with tir::Var that is let bounded * to position from a pool as designated by a PoolAllocation */ @@ -240,8 +241,8 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) { String func_name = Downcast(op->args[0])->value; Array new_args; - if (function_global_vars_.find(func_name) != function_global_vars_.end()) { - GlobalVar gv = function_global_vars_.at(func_name); + if (module_->ContainGlobalVar(func_name)) { + GlobalVar gv = module_->GetGlobalVar(func_name); PrimFunc func = Downcast(module_->Lookup(gv)); PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); module_->Update(gv, prim_func); @@ -304,7 +305,7 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const LoadNode* op) { } IRModule PoolAllocationToOffsetConverter::operator()() { - GlobalVar gv = function_global_vars_.at(::tvm::runtime::symbol::tvm_run_func_suffix); + GlobalVar gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix); PrimFunc main_func = Downcast(module_->Lookup(gv)); ScopeInfo si = UpdateFunctionScopeInfo(main_func); this->scope_stack.push(si); From 3ac83792c51323b660d01be4a70172f534f94b13 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 8 Dec 2021 13:50:50 +0000 Subject: [PATCH 6/7] [TIR][USMP] adding the pass to convert to pool offsets * rebase changes * making imports absolute * fixing typos and removing unnecessary lines Change-Id: I4c94b9955b001513fecb39ca94f81b1ad99c7bfc --- python/tvm/tir/usmp/transform/transform.py | 7 ++++--- .../convert_pool_allocations_to_offsets.cc | 4 +--- ...ransform_convert_pool_allocations_to_offsets.py | 14 ++++++++++---- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py index 4976215c21a6..f472172cf36f 100644 ---
a/python/tvm/tir/usmp/transform/transform.py +++ b/python/tvm/tir/usmp/transform/transform.py @@ -19,14 +19,15 @@ from typing import Dict +import tvm +from tvm.tir import Stmt +from tvm.tir.usmp.utils import PoolAllocation from . import _ffi_api -from ....tir import Stmt -from ..utils import PoolAllocation def convert_pool_allocations_to_offsets( pool_allocations: Dict[Stmt, PoolAllocation], emit_tvmscript_printable: bool = False -): +) -> tvm.transform.Pass: """Convert pool allocations to Load nodes with offsets from pools. Parameters diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index eac83d9b8657..5ebf3c557b06 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -40,7 +40,7 @@ namespace usmp { * \brief The StmtExpr mutator class to replace allocate nodes * with offsets within memory pools * - * This mutator class with add Pool variables recursively to every PrimFunc + * This mutator class will add Pool variables recursively to every PrimFunc * starting from the main PrimFunc. For all allocate nodes, that have been * memory planned, will be mutated into an offset using a Let binding. 
*/ @@ -88,7 +88,6 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { private: PrimExpr VisitExpr_(const CallNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; - // PrimExpr VisitExpr_(const VarNode* op) override; PrimExpr VisitExpr_(const LoadNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; @@ -270,7 +269,6 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) { PoolAllocation pool_allocation = pool_allocations_[GetRef(op)]; Var param = scope_info.pools_to_params[pool_allocation->pool_info]; Buffer buffer_var = scope_info.buffer_map[param]; - ICHECK(pool_allocation->byte_offset < all_pools_sizes_[pool_allocation->pool_info]); Load load_node = Load(DataType::UInt(8), buffer_var->data, pool_allocation->byte_offset, op->condition); Call address_of_load = Call(DataType::Handle(8), builtin::address_of(), {load_node}); diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index a9d17cb9b6f6..4220abcdad74 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -218,12 +218,15 @@ def test_mobilenet_subgraph(): tir_mod, [fast_memory_pool, slow_memory_pool] ) main_func = tir_mod["run_model"] - buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + buffer_info_map = buffer_analysis.buffer_info_stmts fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") buffer_info_arr = fcreate_array_bi(buffer_info_map) fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") - buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr) + buffer_pool_allocations = fusmp_algo_greedy_by_size( + 
buffer_info_arr, buffer_analysis.memory_pressure + ) fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( @@ -489,12 +492,15 @@ def test_resnet_subgraph(): tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) main_func = tir_mod["run_model"] - buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + buffer_info_map = buffer_analysis.buffer_info_stmts fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") buffer_info_arr = fcreate_array_bi(buffer_info_map) fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") - buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr) + buffer_pool_allocations = fusmp_algo_greedy_by_size( + buffer_info_arr, buffer_analysis.memory_pressure + ) fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( From 2c6cf6cbc81be57384573c7d49ba14378ee35487 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 8 Dec 2021 14:52:45 +0000 Subject: [PATCH 7/7] [TIR][USMP] adding the pass to convert to pool offsets * fixing typos Change-Id: I42c557fd394aefdf8c2e825c4e88770eb0732f9b --- ...ir_usmp_transform_convert_pool_allocations_to_offsets.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py 
b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 4220abcdad74..fc615775c160 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -32,7 +32,7 @@ def _get_primfuncs_from_module(module): def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): - """helper to assing poolinfos to allocate nodes in a tir.PrimFunc""" + """Helper to assign poolinfos to allocate nodes in a tir.PrimFunc""" def set_poolinfos(stmt): if isinstance(stmt, tvm.tir.Allocate): @@ -49,7 +49,7 @@ def set_poolinfos(stmt): def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): - """helper to assing poolinfos to allocate nodes in a IRModule""" + """Helper to assign poolinfos to allocate nodes in a IRModule""" ret = tvm.IRModule() for global_var, basefunc in mod.functions.items(): if isinstance(basefunc, tvm.tir.PrimFunc): @@ -58,7 +58,7 @@ def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): def _assign_targets_to_primfuncs_irmodule(mod, target): - """helper to assign target for PrimFunc in a IRModule""" + """Helper to assign target for PrimFunc in a IRModule""" ret = tvm.IRModule() for global_var, basefunc in mod.functions.items(): if isinstance(basefunc, tvm.tir.PrimFunc):