diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/se_scope.h index 595f986686ed..ec5da3a80cae 100644 --- a/include/tvm/target/se_scope.h +++ b/include/tvm/target/se_scope.h @@ -170,7 +170,7 @@ class SEScopeNode : public AttrsNode { * * kInvalidDeviceType denotes unconstrained. */ - int device_type_int; + int /* actually DLDeviceType */ device_type_int; DLDeviceType device_type() const { return static_cast(device_type_int); } @@ -303,6 +303,11 @@ class SEScope : public ObjectRef { return SEScope(device_type, /*virtual_device_id=*/0, std::move(target)); } + /*! \brief Returns the \p SEScope for \p memory_scope alone. */ + static SEScope ForMemoryScope(MemoryScope memory_scope) { + return SEScope(kInvalidDeviceType, -1, {}, std::move(memory_scope)); + } + /*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */ TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target, MemoryScope memory_scope) { diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index f04209d0b061..69453e23ac1a 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -144,7 +144,7 @@ class Buffer : public ObjectRef { public: // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. - TVM_DLL Buffer(Var ptr, DataType dtype, Array shape, Array strides, + TVM_DLL Buffer(Var data, DataType dtype, Array shape, Array strides, PrimExpr elem_offset, String name, int data_alignment, int offset_factor, BufferType buffer_type, Span span = Span()); diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 24773a5a471f..0b4ace20078c 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -280,6 +280,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { */ Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit, std::function fmutate = nullptr); + // internal helper. 
class Internal; }; diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 40a0d1ab2f74..1ac58e18db3e 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -52,7 +52,7 @@ class VarNode : public PrimExprNode { */ String name_hint; /*! - * \brief type annotaion of the variable. + * \brief type annotation of the variable. * * It is an optional field that provides a refined type of the variable than dtype. * diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index d1aaa61c3aae..c2338dd9b611 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -18,6 +18,7 @@ # pylint: disable=invalid-name from typing import Dict, List +from tvm import Object from tvm.tir.stmt import Block, BufferRegion from tvm.tir.stmt import PrimExpr from tvm.tir.expr import Var @@ -196,3 +197,70 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]: Map from buffer to the LCA of all access to it. """ return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint: disable=no-member + + +# NOTE: relay_func_type in the following two functions should be relay.FuncType however that would +# introduce a cyclic dependency. We make do with Object. + + +def get_prim_func_arg_and_result_memory_constraints( + func: PrimFunc, relay_func_type: Object +) -> List[str]: + """Returns the memory (aka storage) scope constraints for all the arguments and result + of func. However the result will be w.r.t. the func's representation as a Relay Function + of relay_func_type before lowering and conversion to DPS. + + Visible for testing. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The function to retrieve constraints from. + + relay_func_type: tvm.relay.FuncType + The type of the Relay Function from which the func was derived. + + Returns + ------- + result: List[str] + Memory scope constraints for func's args and result in Relay form. The empty string + denotes 'no constraint'. 
+ """ + return _ffi_api.GetPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member + func, relay_func_type + ) + + +def apply_prim_func_arg_and_result_memory_constraints( + func: PrimFunc, relay_func_type: Object, arg_and_result_memory_scopes: List[str] +) -> PrimFunc: + """Returns func rewritten to capture the memory (aka storage) scope constraints + for each of the func's parameters given by arg_and_result_memory_scopes. However, + arg_and_result_memory_scopes should be w.r.t. the func's representation as a Relay + Function of relay_func_type before lowering and conversion to DPS. + + Visible for testing. + + CAUTION: This is experimental. The resulting PrimFunc may not have fully accounted + for all new memory scopes. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The function to rewrite. + + relay_func_type: tvm.relay.FuncType + The type of the Relay Function from which the func was derived. + + arg_and_result_memory_scopes: List[str] + Memory constraints for func's args and result in Relay form. The empty string denotes + 'no constraint'. + + Returns + ------- + result: tvm.tir.PrimFunc + The rewritten func. + """ + return _ffi_api.ApplyPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member + func, relay_func_type, arg_and_result_memory_scopes + ) diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc new file mode 100644 index 000000000000..8412cb8b8923 --- /dev/null +++ b/src/tir/analysis/device_constraint_utils.cc @@ -0,0 +1,514 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/device_constraint_utils.cc + * \brief Extracts and applies device-related constraints for \p PrimFunc parameters. + * + * This is used by the \p PlanDevices pass to flow device-constraints *into* \p PrimFuncs. + * + * Currently only applies memory scope constraints into \p Buffer data pointer + * storage scopes. Aliased ('matched') buffers take on any scope introduced on + * the buffer they alias. However currently does not attempt to flow constraints into + * allocated buffers. + */ + +#include "./device_constraint_utils.h" + +#include +#include +#include +#include + +namespace tvm { +namespace tir { +namespace { + +/*! + * \brief Returns the \p PointerTypeNode for \p buffer, or nullptr if \p buffer does not describe a + * pointer. + */ +const PointerTypeNode* PointerInBuffer(const tir::Buffer& buffer) { + return buffer->data->type_annotation.defined() + ? buffer->data->type_annotation.as() + : nullptr; +} + +/*! + * \brief Returns the parameter variable and corresponding buffer at or after \p + * *current_primfunc_param_index in \p prim_func. Will skip over any non-pointer parameters. This + * can be used to find the parameter matching a tensor type in a flattened Relay function parameter + * or result. 
+ */ +std::pair FindPointerParam(const tir::PrimFunc& prim_func, + size_t* current_primfunc_param_index) { + while (true) { + ICHECK_LT(*current_primfunc_param_index, prim_func->params.size()); + const tir::Var& param = prim_func->params[*current_primfunc_param_index]; + auto itr = prim_func->buffer_map.find(param); + if (itr == prim_func->buffer_map.end()) { + VLOG(2) << "no buffer map entry for '" << param->name_hint << "'"; + ++*current_primfunc_param_index; + continue; + } + const auto* pointer_type_node = PointerInBuffer((*itr).second); + if (pointer_type_node == nullptr) { + VLOG(2) << "not a pointer type for '" << param->name_hint << "'"; + ++*current_primfunc_param_index; + continue; + } + VLOG(2) << "using PrimFunc param '" << param->name_hint << "'"; + return *itr; + } +} + +/*! + * \brief Check fails if any parameter at or after \p *current_primfunc_param_index in \p prim_func + * is for a pointer type. This can be used to check all \p prim_func parameters have been accounted + * for when using \p FindPointerParam above. + */ +void CheckNoRemainingPointerParams(const tir::PrimFunc& prim_func, + size_t* current_primfunc_param_index) { + while (*current_primfunc_param_index < prim_func->params.size()) { + const tir::Var& param = prim_func->params[*current_primfunc_param_index]; + auto itr = prim_func->buffer_map.find(param); + if (itr == prim_func->buffer_map.end()) { + VLOG(1) << "no buffer map entry for '" << param->name_hint << "'"; + ++*current_primfunc_param_index; + continue; + } + const auto* pointer_type_node = PointerInBuffer((*itr).second); + ICHECK(pointer_type_node == nullptr); + ++*current_primfunc_param_index; + } +} + +/*! + * \brief Returns the (consistent) constraint to use for a Relay parameter of \p type, + * using \p prim_func parameters at or after \p *current_primfunc_param_index. Currently + * only memory scope is extracted. 
Fails if constraints are not consistent, ie \p type is a tuple + * type and the \p prim_func is attempting to map different fields of that tuple to different memory + * scopes. Returns the fully unconstrained \p SEScope if no memory scope constraints arise from + * the \p prim_func, ie all storage scope strings in pointer types are empty. + */ +SEScope ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type& type, + size_t* current_primfunc_param_index) { + std::string memory_scope; // default empty => no constraint + for (size_t i = 0, n = relay::FlattenTupleType(type).size(); i < n; ++i) { + std::pair kv = FindPointerParam(prim_func, current_primfunc_param_index); + const tir::Buffer& buffer = kv.second; + const auto* pointer_type_node = buffer->data->type_annotation.as(); + const MemoryScope& buffer_memory_scope = pointer_type_node->storage_scope; + if (memory_scope.empty()) { + memory_scope = buffer_memory_scope; + } else if (buffer_memory_scope.empty()) { + // No constraint. + } else { + // Tuples must be homogeneous on their SEScope and thus memory scope. + ICHECK_EQ(buffer_memory_scope, memory_scope); + } + ++*current_primfunc_param_index; + } + return SEScope::ForMemoryScope(memory_scope); +} + +/*! + * \brief Insert into \p param_constraints an entry for each parameter of \p prim_func starting from + * \p *current_primfunc_param_index for the flattened form of a Relay parameter of \p type. Each + * entry maps to \p se_scope. + */ +void InsertParamConstraints(const tir::PrimFunc& prim_func, const Type& type, + const SEScope& se_scope, size_t* current_primfunc_param_index, + std::unordered_map* param_constraints) { + for (size_t i = 0, n = relay::FlattenTupleType(type).size(); i < n; ++i) { + std::pair kv = FindPointerParam(prim_func, current_primfunc_param_index); + param_constraints->emplace(kv.first.get(), se_scope); + ++*current_primfunc_param_index; + } +} + +/*! 
+ * \brief Apply the memory scope constraints to the \p Buffers and data \p Vars of a \p PrimFunc. + * + * All definitional occurrences of buffer Vars are rewritten to capture memory scopes in their + * PointerTypes: + * - Buffer::data (if the buffer itself is a definitional occurrence) + * - FUTURE: AllocateNode::buffer_var + * - FUTURE: LetStmtNode::var if aliasing a buffer data var. + * + * All referential occurrences of buffer Vars are replaced with their new definitions: + * - LoadNode::buffer_var + * - StoreNode::buffer_var + * + * Similarly all definitional occurrences of Buffers are rewritten to account for any new memory + * scopes: + * - PrimFuncNode::buffer_map values. + * - BlockNode::match_buffers.buffer + * - FUTURE: BlockNode::alloc_buffers? + * + * And all referential occurrences of Buffers are replaced with their new definitions: + * - BufferLoadNode::buffer + * - BufferStoreNode::buffer + * - BufferRealizeNode::buffer + * - PrefetchNode::buffer + * - BufferRegionNode::buffer + * - BlockNode::match_buffers.source.buffer + * - BlockNode::{reads, writes}.buffer + * + * CAUTION: We assume strict sharing of Buffer objects and do not attempt to rewrite the bodies + * of referential buffers. + * + * CAUTION: EXPERIMENTAL: We don't yet account for all buffers and pointer types. + */ +class ApplyDeviceConstraintsMutator : public StmtExprMutator { + public: + ApplyDeviceConstraintsMutator() = default; + + /*! + * \brief Returns \p prim_func rewritten to capture the memory scope constraints in + * \p arg_and_result_se_scopes for each pointer \p prim_func parameter. Returns \p prim_func + * unchanged if no memory scopes needed to change. + */ + PrimFunc Rewrite(const PrimFunc& prim_func, const FuncType& relay_func_type, + const Array& arg_and_result_se_scopes) { + size_t current_primfunc_param_index = 0; + std::unordered_map param_constraints; + + // For each Relay function parameter... 
+ for (size_t i = 0; i < relay_func_type->arg_types.size(); ++i) { + const Type& param_type = relay_func_type->arg_types[i]; + const SEScope& param_se_scope = arg_and_result_se_scopes[i]; + InsertParamConstraints(prim_func, param_type, param_se_scope, &current_primfunc_param_index, + &param_constraints); + } + + // For the Relay function result... + const Type& ret_type = relay_func_type->ret_type; + const SEScope& ret_se_scope = arg_and_result_se_scopes.back(); + InsertParamConstraints(prim_func, ret_type, ret_se_scope, &current_primfunc_param_index, + &param_constraints); + + // Make sure we accounted for all prim_func parameters. + CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index); + + // Start with a copy of the current prim_func buffer map. + Map new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end()); + bool any_change = false; + + // For each constrained parameter... + for (const auto& kv : param_constraints) { + const tir::Var param = GetRef(kv.first); + const SEScope& se_scope = kv.second; + const tir::Buffer& buffer = prim_func->buffer_map[param]; + // Rewrite the buffer to account for constraint. + const Buffer new_buffer = RewriteBuffer(buffer, se_scope); + if (!new_buffer.same_as(buffer)) { + any_change = true; + } + new_buffer_map.Set(param, new_buffer); + } + // Re-check (redundant with the check above) that all prim_func parameters are accounted for. + CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index); + + // Apply data variable and buffer substitutions to the prim_func body. These will have been + // accumulated from processing the parameters above. + Stmt new_body = VisitStmt(prim_func->body); + if (!new_body.same_as(prim_func->body)) { + any_change = true; + } + + // We are done with the substitutions. 
+ var_subst_.clear(); + buffer_subst_.clear(); + + if (any_change) { + return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type, + std::move(new_buffer_map), prim_func->attrs, prim_func->span); + } else { + return prim_func; + } + } + + private: + PrimExpr VisitExpr_(const VarNode* var_node) final { return Subst(var_node); } + + PrimExpr VisitExpr_(const LoadNode* load_node) final { + Load new_load = Downcast(StmtExprMutator::VisitExpr_(load_node)); + Var new_buffer_var = Subst(new_load->buffer_var.get()); + if (!new_buffer_var.same_as(new_load->buffer_var)) { + return Load(new_load->dtype, new_buffer_var, new_load->index, new_load->predicate); + } + return new_load; + } + + PrimExpr VisitExpr_(const BufferLoadNode* buffer_load_node) final { + BufferLoad new_buffer_load = + Downcast(StmtExprMutator::VisitExpr_(buffer_load_node)); + Buffer new_buffer = Subst(new_buffer_load->buffer.get()); + if (!new_buffer.same_as(new_buffer_load->buffer)) { + return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->span); + } + return new_buffer_load; + } + + Stmt VisitStmt_(const LetStmtNode* let_stmt_node) final { + // TODO(mbs): If the let-bound var is aliasing an existing buffer data var we need to + // rewrite it. + return StmtExprMutator::VisitStmt_(let_stmt_node); + } + + Stmt VisitStmt_(const AttrStmtNode* attr_stmt_node) final { + AttrStmt new_attr_stmt = Downcast(StmtExprMutator::VisitStmt_(attr_stmt_node)); + // remap node if a var + if (const auto* var_node = new_attr_stmt->node.as()) { + Var new_var = Subst(var_node); + if (!new_var.same_as(new_attr_stmt->node)) { + return AttrStmt(new_var, new_attr_stmt->attr_key, new_attr_stmt->value, + new_attr_stmt->body); + } + } + return new_attr_stmt; + } + + // ForNode default ok since loop_var never of PointerType + + // WhileNode default ok + + Stmt VisitStmt_(const AllocateNode* allocate_node) final { + // TODO(mbs): What memory scope should we assign to the new pointer? 
+ return StmtExprMutator::VisitStmt_(allocate_node); + } + + Stmt VisitStmt_(const StoreNode* store_node) final { + Store new_store = Downcast(StmtExprMutator::VisitStmt_(store_node)); + Var new_buffer_var = Subst(new_store->buffer_var.get()); + if (!new_buffer_var.same_as(new_store->buffer_var)) { + return Store(new_buffer_var, new_store->value, new_store->index, new_store->predicate); + } + return new_store; + } + + Stmt VisitStmt_(const BufferStoreNode* buffer_store_node) final { + BufferStore new_buffer_store = + Downcast(StmtExprMutator::VisitStmt_(buffer_store_node)); + Buffer new_buffer = Subst(new_buffer_store->buffer.get()); + if (!new_buffer.same_as(new_buffer_store->buffer)) { + return BufferStore(new_buffer, new_buffer_store->value, new_buffer_store->indices, + new_buffer_store->span); + } + return new_buffer_store; + } + + Stmt VisitStmt_(const BufferRealizeNode* buffer_realize_node) final { + BufferRealize new_buffer_realize = + Downcast(StmtExprMutator::VisitStmt_(buffer_realize_node)); + Buffer new_buffer = Subst(new_buffer_realize->buffer.get()); + if (!new_buffer.same_as(new_buffer_realize->buffer)) { + return BufferRealize(new_buffer, new_buffer_realize->bounds, new_buffer_realize->condition, + new_buffer_realize->body, new_buffer_realize->span); + } + return new_buffer_realize; + } + + // IfThenElseNode default ok + // AssertStmtNode default ok + // ProducerStoreNode default ok (though does not visit producer) + // ProducerRealizeNode default ok (though does not visit producer) + + Stmt VisitStmt_(const PrefetchNode* prefetch_node) final { + Prefetch new_prefetch = Downcast(StmtExprMutator::VisitStmt_(prefetch_node)); + Buffer new_buffer = Subst(new_prefetch->buffer.get()); + if (!new_buffer.same_as(new_prefetch->buffer)) { + return Prefetch(new_buffer, new_prefetch->bounds, new_prefetch->span); + } + return new_prefetch; + } + + // SeqStmtNode default ok + // EvaluateNode default ok + + BufferRegion VisitItem(const BufferRegionNode* 
buffer_region_node) { + Buffer new_buffer = Subst(buffer_region_node->buffer.get()); + if (!new_buffer.same_as(buffer_region_node->buffer)) { + return BufferRegion(new_buffer, buffer_region_node->region); + } + return GetRef(buffer_region_node); + } + + MatchBufferRegion VisitItem(const MatchBufferRegionNode* match_buffer_region_node) { + // The source field has a referential occurrence of the buffer. Apply the buffer substitution + // to that. + BufferRegion new_source = VisitItem(match_buffer_region_node->source.get()); + // The buffer field however is a definitional occurrence, aliased on top of the source. + // Transfer any memory scope from the source to the destination. + Optional opt_se_scope = GetBufferConstraint(new_source->buffer); + tir::Buffer new_buffer; + if (opt_se_scope.defined()) { + new_buffer = RewriteBuffer(match_buffer_region_node->buffer, opt_se_scope.value()); + } else { + new_buffer = match_buffer_region_node->buffer; + } + if (!new_buffer.same_as(match_buffer_region_node->buffer) || + !new_source.same_as(match_buffer_region_node->source)) { + return MatchBufferRegion(new_buffer, new_source); + } + return GetRef(match_buffer_region_node); + } + + template + Array VisitItems(Array items) { + items.MutateByApply([this](const T& item) { return VisitItem(item.get()); }); // copy-on-write + return items; + } + + Stmt VisitStmt_(const BlockNode* block_node) final { + Block new_block = Downcast(StmtExprMutator::VisitStmt_(block_node)); + Array new_reads = VisitItems(new_block->reads); + Array new_writes = VisitItems(new_block->writes); + // TODO(mbs): What memory scope should we assign to the new buffers? 
+ Array new_match_buffers = VisitItems(new_block->match_buffers); + if (!new_reads.same_as(new_block->reads) || !new_writes.same_as(new_block->writes) || + !new_match_buffers.same_as(new_block->match_buffers)) { + return Block(new_block->iter_vars, std::move(new_reads), std::move(new_writes), + new_block->name_hint, new_block->body, new_block->init, new_block->alloc_buffers, + std::move(new_match_buffers), new_block->annotations, new_block->span); + } + return new_block; + } + + // BlockRealizeNode default ok + + /*! Applies \p var_subst_ substitution to \p var_node. */ + Var Subst(const VarNode* var_node) const { + auto itr = var_subst_.find(var_node); + return itr == var_subst_.end() ? GetRef(var_node) : itr->second; + } + + /*! Applies \p buffer_subst_ substitution to \p buffer_node. */ + Buffer Subst(const BufferNode* buffer_node) const { + auto itr = buffer_subst_.find(buffer_node); + return itr == buffer_subst_.end() ? GetRef(buffer_node) : itr->second; + } + + /*! + * \brief Rewrites \p buffer so as to follow the constraints in \p se_scope + * (currently just memory scope). + * + * Updates both the var_subst_ and buffer_subst_ to capture the rewrite, but + * also returns the new buffer. + */ + Buffer RewriteBuffer(const Buffer& buffer, const SEScope& se_scope) { + ICHECK(buffer->data->type_annotation.defined()); + const auto* pointer_type_node = buffer->data->type_annotation.as(); + ICHECK(pointer_type_node); + if (pointer_type_node->storage_scope == se_scope->memory_scope) { + // No change. 
+ return buffer; + } + PointerType new_pointer_type(pointer_type_node->element_type, se_scope->memory_scope); + Var new_data(buffer->data->name_hint, new_pointer_type, buffer->data->span); + var_subst_.emplace(buffer->data.get(), new_data); + Buffer new_buffer(new_data, buffer->dtype, buffer->shape, buffer->strides, buffer->elem_offset, + buffer->name, buffer->data_alignment, buffer->offset_factor, + buffer->buffer_type, buffer->span); + buffer_subst_.emplace(buffer.get(), new_buffer); + return new_buffer; + } + + /*! + * \brief Returns the SEScope capturing any memory scope in \p buffer. Returns nullptr if + * buffer's data var does not have a type annotation of \p PointerType. Returns the fully + * unconstrained \p SEScope if no memory scope is given. + */ + static Optional GetBufferConstraint(const tir::Buffer& buffer) { + const auto* pointer_type_node = PointerInBuffer(buffer); + return pointer_type_node == nullptr ? Optional() + : SEScope::ForMemoryScope(pointer_type_node->storage_scope); + } + + /*! + * \brief Maps each \p Buffer::data \p Var to its constrained equivalent. + */ + std::unordered_map var_subst_; + + /*! + * \brief Maps each \p Buffer to its constrained equivalent. + */ + std::unordered_map buffer_subst_; +}; + +} // namespace + +Array GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func, + const FuncType& relay_func_type) { + // Build the implied domain (in terms of the function's Relay type) implied by any memory scope + // constrains in the function's buffers, for both arguments and results. + Array se_scopes; + se_scopes.reserve(relay_func_type->arg_types.size() + 1); + + // For each Relay function parameter... + size_t current_primfunc_param_index = 0; + for (const auto& param_type : relay_func_type->arg_types) { + SEScope param_se_scope = + ConsistentParamConstraint(prim_func, param_type, ¤t_primfunc_param_index); + se_scopes.push_back(param_se_scope); + } + + // For the Relay function result... 
+ const Type& ret_type = relay_func_type->ret_type; + SEScope ret_se_scope = + ConsistentParamConstraint(prim_func, ret_type, ¤t_primfunc_param_index); + se_scopes.push_back(ret_se_scope); + + // Make sure all parameters of the prim_func have been accounted for. + CheckNoRemainingPointerParams(prim_func, ¤t_primfunc_param_index); + + return se_scopes; +} + +TVM_REGISTER_GLOBAL("tir.analysis.GetPrimFuncArgAndResultMemoryConstraints") + .set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type) { + Array memory_scopes; + memory_scopes.reserve(relay_func_type->type_params.size() + 1); + for (const auto& se_scope : GetPrimFuncArgAndResultConstraints(prim_func, relay_func_type)) { + memory_scopes.push_back(se_scope->memory_scope); + } + return memory_scopes; + }); + +PrimFunc ApplyPrimFuncArgAndResultConstraints(const PrimFunc& prim_func, + const FuncType& relay_func_type, + const Array& arg_and_result_se_scopes) { + return ApplyDeviceConstraintsMutator().Rewrite(prim_func, relay_func_type, + arg_and_result_se_scopes); +} + +TVM_REGISTER_GLOBAL("tir.analysis.ApplyPrimFuncArgAndResultMemoryConstraints") + .set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type, + const Array& arg_and_result_memory_scopes) { + Array se_scopes; + se_scopes.reserve(arg_and_result_memory_scopes.size()); + for (const auto& memory_scope : arg_and_result_memory_scopes) { + se_scopes.push_back(SEScope::ForMemoryScope(memory_scope)); + } + return ApplyPrimFuncArgAndResultConstraints(prim_func, relay_func_type, se_scopes); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/analysis/device_constraint_utils.h b/src/tir/analysis/device_constraint_utils.h new file mode 100644 index 000000000000..be0f199f5226 --- /dev/null +++ b/src/tir/analysis/device_constraint_utils.h @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/device_constraint_utils.h + * \brief Utilities for extracting and applying device-related constraints to \p PrimFunc + * parameters. + * + * These utilities are used by the \p PlanDevices pass to extract memory (aka 'storage') scope + * information from \p PrimFuncs and convert them back into \p SEScope form w.r.t. the original + * Relay type of the \p PrimFunc (ie before flattening of tuple arguments/results and conversion + * to destination-passing style aka DPS). + * + * A utility is also supplied to go the other way: impose memory scopes on \p PrimFunc parameters. + * However that's still in EXPERIMENTAL form. + * + * We may extend these utilities to also gather/apply layout information should we add that to + * \p SEScope. + */ + +#ifndef TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_ +#define TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_ + +#include +#include + +namespace tvm { +namespace tir { + +/*! + * A Relay Function with type: + * \code + * fn((Tensor[...], Tensor[...]), Tensor[...]) -> (Tensor[...], Tensor[...]) + * ^ ^ ^ ^ ^ + * a b c d e + * \endcode + * will be represented by a TIR PrimFunc in flattened and DPS form with at least 5 arguments a..e. 
+ * \code + * primfn(a: handle, b: handle, c: handle, d: handle, e: handle) { + * buffers = { ... } + * buffer_map = { ... } + * ... + * } + * \endcode + * + * Each such PrimFunc argument will be mapped to a \p Buffer whose underlying \p data \p Var + * has a \p PointerType. + * + * The PrimFunc may have additional non-pointer arguments, eg for: + * - scalar inputs and tensor dimensions + * - device contexts + * Those should be ignored here since they have no counterpart in the Relay Function. + * + * We'll need helpers to map on-the-fly between the Relay and TIR view of functions. + */ + +/*! + * \brief Returns the \p SEScopes capturing the memory (aka storage) scope constraints for all the + * arguments and result of \p prim_func. However the result will be w.r.t. the \p prim_func's + * representation as a Relay \p Function of \p relay_func_type before lowering and conversion to + * DPS. + */ +Array GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func, + const FuncType& relay_func_type); + +/*! + * \brief Returns \p prim_func rewritten to capture the memory (aka storage) scope constraints + * for each of the \p prim_func's parameters given by \p arg_and_result_se_scopes. However, + * \p arg_and_result_se_scopes should be w.r.t. the \p prim_func's representation as a Relay + * \p Function of \p relay_func_type before lowering and conversion to DPS. + * + * CAUTION: This is experimental. The resulting \p PrimFunc may not have fully accounted for all + * new memory scopes. 
+ */ +PrimFunc ApplyPrimFuncArgAndResultConstraints(const PrimFunc& prim_func, + const FuncType& relay_func_type, + const Array& arg_and_result_se_scopes); + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_ diff --git a/tests/python/tir/analysis/test_device_constraint_utils.py b/tests/python/tir/analysis/test_device_constraint_utils.py new file mode 100644 index 000000000000..65cb4e398294 --- /dev/null +++ b/tests/python/tir/analysis/test_device_constraint_utils.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Test retrieving and applying memory scope constraints to PrimFuncs""" +import tvm +from tvm import tir +from tvm import relay +from tvm.script import tir as T + + +@T.prim_func +def gem(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128], scope="scopeA") + B = T.match_buffer(b, [128, 128], scope="scopeA") + C = T.match_buffer(c, [128, 128], scope="scopeB") + D = T.match_buffer(d, [128, 128], scope="scopeC") + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] + + +gem_ty = relay.FuncType( + [ + relay.TupleType( + [ + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + ] + ), + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), +) + + +def test_get_prim_func_arg_and_result_constraints(): + scopes = tir.analysis.get_prim_func_arg_and_result_memory_constraints(gem, gem_ty) + assert [x for x in scopes] == ["scopeA", "scopeB", "scopeC"] + + +def test_apply_prim_func_arg_and_result_memory_constraints(): + rewritten = tir.analysis.apply_prim_func_arg_and_result_memory_constraints( + gem, gem_ty, ["scopeX", "scopeY", "scopeZ"] + ) + scopes = tir.analysis.get_prim_func_arg_and_result_memory_constraints(rewritten, gem_ty) + assert [x for x in scopes] == ["scopeX", "scopeY", "scopeZ"] + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:]))