From ce37cd58e8cbe6a6d8f732fd83a8e1a86314427b Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Wed, 8 Dec 2021 17:16:49 -0800 Subject: [PATCH 1/2] [TIR] Allow memory (aka storage) scopes to be retrieved/applied to PrimFuncs This is in support of #9613 which allows memory scopes to flow out of already-lowered PrimFuncs into the rest of the Relay program. This means scope choices made during lowering can be accounted for in the rest of the program, with device_copies inserted as required. Somewhat more speculatively we also allow memory scopes to flow in to PrimFuncs. This is in preparation for when we can split lowering into two phases: i) lower "primitive" fused Relay functions to TensorIR in a schedulable form roughly isomorphic to TE, and ii) actual scheduling down to traditional TIR. Once that split is made it will be possible to flow memory scopes out of one PrimFunc and into another so as to avoid unnecessary device_copies being necessary due to independently chosen memory scopes. I also suspect we'll want to put our focus on layouts rather than memory scopes, but this at least sets up some of the machinery. 
--- include/tvm/target/se_scope.h | 7 +- include/tvm/tir/analysis.h | 3 + include/tvm/tir/buffer.h | 2 +- include/tvm/tir/stmt_functor.h | 1 + include/tvm/tir/var.h | 2 +- python/tvm/tir/analysis/analysis.py | 70 ++- src/tir/analysis/device_constraint_utils.cc | 523 ++++++++++++++++++ src/tir/analysis/device_constraint_utils.h | 85 +++ .../analysis/test_device_constraint_utils.py | 70 +++ 9 files changed, 759 insertions(+), 4 deletions(-) create mode 100644 src/tir/analysis/device_constraint_utils.cc create mode 100644 src/tir/analysis/device_constraint_utils.h create mode 100644 tests/python/tir/analysis/test_device_constraint_utils.py diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/se_scope.h index 595f986686ed..ec5da3a80cae 100644 --- a/include/tvm/target/se_scope.h +++ b/include/tvm/target/se_scope.h @@ -170,7 +170,7 @@ class SEScopeNode : public AttrsNode { * * kInvalidDeviceType denotes unconstrained. */ - int device_type_int; + int /* actually DLDeviceType */ device_type_int; DLDeviceType device_type() const { return static_cast(device_type_int); } @@ -303,6 +303,11 @@ class SEScope : public ObjectRef { return SEScope(device_type, /*virtual_device_id=*/0, std::move(target)); } + /*! \brief Returns the \p SEScope for \p memory_scope alone. */ + static SEScope ForMemoryScope(MemoryScope memory_scope) { + return SEScope(kInvalidDeviceType, -1, {}, std::move(memory_scope)); + } + /*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. 
*/ TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target, MemoryScope memory_scope) { diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 51bdb18d2217..fa63b2617f48 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -26,12 +26,14 @@ #include #include +#include #include #include #include #include #include +#include namespace tvm { namespace tir { @@ -242,4 +244,5 @@ TVM_DLL Pass VerifyGPUCode(Map constraints); } // namespace transform } // namespace tir } // namespace tvm + #endif // TVM_TIR_ANALYSIS_H_ diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index f04209d0b061..69453e23ac1a 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -144,7 +144,7 @@ class Buffer : public ObjectRef { public: // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. - TVM_DLL Buffer(Var ptr, DataType dtype, Array shape, Array strides, + TVM_DLL Buffer(Var data, DataType dtype, Array shape, Array strides, PrimExpr elem_offset, String name, int data_alignment, int offset_factor, BufferType buffer_type, Span span = Span()); diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 24773a5a471f..0b4ace20078c 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -280,6 +280,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { */ Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit, std::function fmutate = nullptr); + // internal helper. class Internal; }; diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 40a0d1ab2f74..1ac58e18db3e 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -52,7 +52,7 @@ class VarNode : public PrimExprNode { */ String name_hint; /*! - * \brief type annotaion of the variable. + * \brief type annotation of the variable. * * It is an optional field that provides a refined type of the variable than dtype. 
* diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index d1aaa61c3aae..c74837aa820a 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -16,8 +16,9 @@ # under the License. """Wrapping existing analysis utils.""" # pylint: disable=invalid-name -from typing import Dict, List +from typing import Dict, List, AnyStr +from tvm import Object from tvm.tir.stmt import Block, BufferRegion from tvm.tir.stmt import PrimExpr from tvm.tir.expr import Var @@ -196,3 +197,70 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]: Map from buffer to the LCA of all access to it. """ return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint: disable=no-member + + +# NOTE: relay_func_type in the following two functions should be relay.FuncType however that would +# introduce a cycling dependency. We make do with Object. + + +def get_prim_func_arg_and_result_memory_constraints( + func: PrimFunc, relay_func_type: Object +) -> List[AnyStr]: + """Returns the memory (aka storage) scope constraints for all the arguments and result + of func. However the result will be w.r.t. the func's representation as a Relay Function + of relay_func_type before lowering and conversion to DPS. + + Visible for testing. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The function to retrieve constraints from. + + relay_func_type: tvm.relay.FuncType + The type of the Relay Function from which the func was derived. + + Returns + ------- + result: List[AnyStr] + Memory scope constraints for funcs args and result in Relay form. The empty string + denotes 'no constraint'. 
+ """ + return _ffi_api.GetPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member + func, relay_func_type + ) + + +def apply_prim_func_arg_and_result_memory_constraints( + func: PrimFunc, relay_func_type: Object, arg_and_result_memory_scopes: List[AnyStr] +) -> PrimFunc: + """Returns func written to capture the memory (aka storage) scope constraints + for each of the func's parameters given by arg_and_result_memory_scopes. However, + arg_and_result_memory_scopes should be w.r.t. the func's representation as a Relay + Function of relay_func_type before lowering and conversion to DPS. + + Visible for testing. + + CAUTION: This is experimental. The resulting PrimFunc may not have fully accounted + for all new memory scopes. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The function to retrieve constraints from. + + relay_func_type: tvm.relay.FuncType + The type of the Relay Function from which the func was derived. + + arg_and_result_memory_scopes: Array[AnyStr] + Memory constraints for funcs args and result in Relay form. The empty string denotes + 'no constraint'. + + Returns + ------- + result: tvm.tir.PrimFunc + The rewritten func. + """ + return _ffi_api.ApplyPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member + func, relay_func_type, arg_and_result_memory_scopes + ) diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc new file mode 100644 index 000000000000..a2b2e046ab10 --- /dev/null +++ b/src/tir/analysis/device_constraint_utils.cc @@ -0,0 +1,523 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/apply_device_constraints.cc + * \brief Applies device-related constraints to \p PrimFunc parameters. + * + * This is used by the \p PlanDevices pass to flow device-constraints *into* \p PrimFuncs. + * + * Currently only applies memory scope constraints into \p Buffer data pointer + * storage scopes. Aliased ('matched') buffers take on any scope introduced on + * the buffer they alias. However currently does not attempt to flow constraints into + * allocated buffers. + */ + +#include "./device_constraint_utils.h" + +#include +#include +#include +#include + +namespace tvm { +namespace tir { +namespace { + +/*! + * \brief Returns the \p PointerTypeNode for \p buffer, or nullptr if \p buffer does not describe a + * pointer. + */ +const PointerTypeNode* PointerInBuffer(const tir::Buffer& buffer) { + return buffer->data->type_annotation.defined() + ? buffer->data->type_annotation.as() + : nullptr; +} + +/*! + * \brief Returns the parameter variable and corresponding buffer at or after \p + * *current_primfunc_param_index in \p prim_func. Will skip over any non-pointer parameters. This + * can be used to find the parameter matching a tensor type in a flattened Relay function parameter + * or result. 
+ */ +std::pair FindPointerParam(const tir::PrimFunc& prim_func, + size_t* current_primfunc_param_index) { + while (true) { + ICHECK_LT(*current_primfunc_param_index, prim_func->params.size()); + const tir::Var& param = prim_func->params[*current_primfunc_param_index]; + auto itr = prim_func->buffer_map.find(param); + if (itr == prim_func->buffer_map.end()) { + VLOG(2) << "no buffer map entry for '" << param->name_hint << "'"; + ++*current_primfunc_param_index; + continue; + } + const auto* pointer_type_node = PointerInBuffer((*itr).second); + if (pointer_type_node == nullptr) { + VLOG(2) << "not a pointer type for '" << param->name_hint << "'"; + ++*current_primfunc_param_index; + continue; + } + VLOG(2) << "using PrimFunc param '" << param->name_hint << "'"; + return *itr; + } +} + +/*! + * \brief Check fails if any parameter at or after \p *current_primfunc_param_index in \p prim_func + * is for a pointer type. This can be used to check all \p prim_func parameters have been accounted + * for when using \p FindPointerParam above. + */ +void CheckNoRemainingPointerParams(const tir::PrimFunc& prim_func, + size_t* current_primfunc_param_index) { + while (*current_primfunc_param_index < prim_func->params.size()) { + const tir::Var& param = prim_func->params[*current_primfunc_param_index]; + auto itr = prim_func->buffer_map.find(param); + if (itr == prim_func->buffer_map.end()) { + VLOG(1) << "no buffer map entry for '" << param->name_hint << "'"; + ++*current_primfunc_param_index; + continue; + } + const auto* pointer_type_node = PointerInBuffer((*itr).second); + ICHECK(pointer_type_node == nullptr); + ++*current_primfunc_param_index; + } +} + +/*! + * \brief Returns the (consistent) constraint to use for a Relay parameter of \p type, + * using \p prim_func parameters at or after \p *current_primfunc_param_index. Currently + * only memory scope is extracted. 
Fails if constraints are not consistent, ie \p type is a tuple + type and the \p prim_func is attempting to map different fields of that tuple to different memory + scopes. Returns the fully unconstrained \p SEScope if no memory scope constraints arise from + the \p prim_func, ie all storage scope strings in pointer types are empty. + */ +SEScope ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type& type, + size_t* current_primfunc_param_index) { + std::string memory_scope; // default empty => no constraint + for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) { + std::pair kv = FindPointerParam(prim_func, current_primfunc_param_index); + const tir::Buffer& buffer = kv.second; + const auto* pointer_type_node = buffer->data->type_annotation.as(); + const MemoryScope& buffer_memory_scope = pointer_type_node->storage_scope; + if (memory_scope.empty()) { + memory_scope = buffer_memory_scope; + } else if (buffer_memory_scope.empty()) { + // No constraint. + } else { + // Tuples must be homogeneous on their SEScope and thus memory scope. + ICHECK_EQ(buffer_memory_scope, memory_scope); + } + ++*current_primfunc_param_index; + } + return SEScope::ForMemoryScope(memory_scope); +} + +/*! + * \brief Insert into param_constraints an entry for each parameter of \p prim_func starting from + * \p *current_primfunc_param_index for the flattened form of a Relay parameter of \p type. Each + * entry maps to \p se_scope. + */ +void InsertParamConstraints(const tir::PrimFunc& prim_func, const Type& type, + const SEScope& se_scope, size_t* current_primfunc_param_index, + std::unordered_map* param_constraints) { + for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) { + std::pair kv = FindPointerParam(prim_func, current_primfunc_param_index); + param_constraints->emplace(kv.first.get(), se_scope); + ++*current_primfunc_param_index; + } +} + +/*! 
+ * \brief Apply the memory scope constraints to the \p Buffers and data \p Vars of a \p PrimFunc. + * + * All definitional occurrences of buffer Vars are rewritten to capture memory scopes in their + * PointerTypes: + * - Buffer::data (if the buffer itself is a definitional occurrence) + * - AllocateNode::buffer_var + * - FUTURE: LetStmtNode::var if aliasing a buffer data var. + * + * All referential occurrences of buffer Vars are replaced with their new definitions: + * - LoadNode::buffer_var + * - StoreNode::buffer_var + * + * Similarly all definitional occurrences of Buffers are rewritten to account for any new memory + * scopes: + * - PrimFuncNode::buffer_map keys. + * - BlockNode::match_buffers.buffer + * - FUTURE: BlockNode::alloc_buffers? + * + * And all referential occurrences of Buffers are replaced with their new definitions: + * - BufferLoadNode::buffer + * - BufferStoreNode::buffer + * - BufferRealizeNode::buffer + * - PrefetchNode::buffer + * - BufferRegionNode:buffer + * - BlockNode.match_buffers.source.buffer + * - BlockNode::{reads, writes}.buffer + * + * CAUTION: We assume strict sharing of Buffer objects and do not attempt to rewrite the bodies + * of referential buffers. + * + * CAUTION: EXPERIMENTAL: We don't yet account for all buffers and pointer types. + */ +class ApplyDeviceConstraintsMutator : public StmtExprMutator { + public: + ApplyDeviceConstraintsMutator() = default; + + /*! + * \brief Returns \p prim_func written to capture the memory scope constraints in \p + * param_constraints for each pointer \p prim_func parameter. Returns \p prim_func unchanged if no + * memory scopes needed to change. + */ + PrimFunc Rewrite(const PrimFunc& prim_func, const FuncType& relay_func_type, + const Array& arg_and_result_se_scopes) { + size_t current_primfunc_param_index = 0; + std::unordered_map param_constraints; + + // For each Relay function parameter... 
+ for (size_t i = 0; i < relay_func_type->arg_types.size(); ++i) { + const Type& param_type = relay_func_type->arg_types[i]; + const SEScope& param_se_scope = arg_and_result_se_scopes[i]; + InsertParamConstraints(prim_func, param_type, param_se_scope, &current_primfunc_param_index, + &param_constraints); + } + + // For the Relay function result... + const Type& ret_type = relay_func_type->ret_type; + const SEScope& ret_se_scope = arg_and_result_se_scopes.back(); + InsertParamConstraints(prim_func, ret_type, ret_se_scope, &current_primfunc_param_index, + &param_constraints); + + // Make sure we accounted for all prim_func parameters. + CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index); + + // Start with a copy of the current prim_func buffer map. + Map new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end()); + bool any_change = false; + + // For each constrained parameter... + for (const auto& kv : param_constraints) { + const tir::Var param = GetRef(kv.first); + const SEScope& se_scope = kv.second; + const tir::Buffer& buffer = prim_func->buffer_map[param]; + // Rewrite the buffer to account for constraint. + const Buffer new_buffer = RewriteBuffer(buffer, se_scope); + if (!new_buffer.same_as(buffer)) { + any_change = true; + } + new_buffer_map.Set(param, new_buffer); + } + // Make sure we have accounted for all prim_func parameters. + CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index); + + // Apply data variable and buffer substitutions to the prim_func body. These will have been + // accumulated from processing the parameters above. + Stmt new_body = VisitStmt(prim_func->body); + if (!new_body.same_as(prim_func->body)) { + any_change = true; + } + + // We are done with the substitutions. 
+ var_subst_.clear(); + buffer_subst_.clear(); + + if (any_change) { + return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type, + std::move(new_buffer_map), prim_func->attrs, prim_func->span); + } else { + return prim_func; + } + } + + private: + PrimExpr VisitExpr_(const VarNode* var_node) final { return Subst(var_node); } + + PrimExpr VisitExpr_(const LoadNode* load_node) final { + Load new_load = Downcast(StmtExprMutator::VisitExpr_(load_node)); + Var new_buffer_var = Subst(new_load->buffer_var.get()); + if (!new_buffer_var.same_as(new_load->buffer_var)) { + return Load(load_node->dtype, new_buffer_var, load_node->index, load_node->predicate); + } + return new_load; + } + + PrimExpr VisitExpr_(const BufferLoadNode* buffer_load_node) final { + BufferLoad new_buffer_load = + Downcast(StmtExprMutator::VisitExpr_(buffer_load_node)); + Buffer new_buffer = Subst(new_buffer_load->buffer.get()); + if (!new_buffer.same_as(new_buffer_load->buffer)) { + return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->span); + } + return new_buffer_load; + } + + Stmt VisitStmt_(const LetStmtNode* let_stmt_node) final { + // TODO(mbs): If the let-bound var is aliasing an existing buffer data var we need to + // rewrite it. + return StmtExprMutator::VisitStmt_(let_stmt_node); + } + + Stmt VisitStmt_(const AttrStmtNode* attr_stmt_node) final { + AttrStmt new_attr_stmt = Downcast(StmtExprMutator::VisitStmt_(attr_stmt_node)); + // remap node if a var + if (const auto* var_node = new_attr_stmt->node.as()) { + Var new_var = Subst(var_node); + if (!new_var.same_as(new_attr_stmt->node)) { + return AttrStmt(new_var, new_attr_stmt->attr_key, new_attr_stmt->value, + new_attr_stmt->body); + } + } + return new_attr_stmt; + } + + // ForNode default ok since loop_var never of PointerType + + // WhileNode default ok + + Stmt VisitStmt_(const AllocateNode* allocate_node) final { + // TODO(mbs): What memory scope should we assign to the new pointer? 
+ return StmtExprMutator::VisitStmt_(allocate_node); + } + + Stmt VisitStmt_(const StoreNode* store_node) final { + Store new_store = Downcast(StmtExprMutator::VisitStmt_(store_node)); + Var new_buffer_var = Subst(new_store->buffer_var.get()); + if (!new_buffer_var.same_as(new_store->buffer_var)) { + return Store(new_buffer_var, new_store->value, new_store->index, new_store->predicate); + } + return new_store; + } + + Stmt VisitStmt_(const BufferStoreNode* buffer_store_node) final { + BufferStore new_buffer_store = + Downcast(StmtExprMutator::VisitStmt_(buffer_store_node)); + Buffer new_buffer = Subst(new_buffer_store->buffer.get()); + if (!new_buffer.same_as(new_buffer_store->buffer)) { + return BufferStore(new_buffer, new_buffer_store->value, new_buffer_store->indices, + new_buffer_store->span); + } + return new_buffer_store; + } + + Stmt VisitStmt_(const BufferRealizeNode* buffer_realize_node) final { + BufferRealize new_buffer_realize = + Downcast(StmtExprMutator::VisitStmt_(buffer_realize_node)); + Buffer new_buffer = Subst(new_buffer_realize->buffer.get()); + if (!new_buffer.same_as(new_buffer_realize->buffer)) { + return BufferRealize(new_buffer, new_buffer_realize->bounds, new_buffer_realize->condition, + new_buffer_realize->body, new_buffer_realize->span); + } + return new_buffer_realize; + } + + // IfThenElseNode default ok + // AssertStmtNode default ok + // ProducerStoreNode default ok (though does not visit producer) + // ProducerRealizeNode default ok (though does not visit producer) + + Stmt VisitStmt_(const PrefetchNode* prefetch_node) final { + Prefetch new_prefetch = Downcast(StmtExprMutator::VisitStmt_(prefetch_node)); + Buffer new_buffer = Subst(new_prefetch->buffer.get()); + if (!new_buffer.same_as(new_prefetch->buffer)) { + return Prefetch(new_buffer, prefetch_node->bounds, prefetch_node->span); + } + return new_prefetch; + } + + // SeqStmtNode default ok + // EvaluateNode default ok + + BufferRegion VisitItem(const BufferRegionNode* 
buffer_region_node) { + Buffer new_buffer = Subst(buffer_region_node->buffer.get()); + if (!new_buffer.same_as(buffer_region_node->buffer)) { + return BufferRegion(new_buffer, buffer_region_node->region); + } + return GetRef(buffer_region_node); + } + + MatchBufferRegion VisitItem(const MatchBufferRegionNode* match_buffer_region_node) { + // The source field has a referential occurrence of the buffer. Apply the buffer substitution + // to that. + BufferRegion new_source = VisitItem(match_buffer_region_node->source.get()); + // The buffer field however is a definitional occurrence, aliased on top of the source. + // Transfer any memory scope from the source to the destination. + Optional opt_se_scope = GetBufferConstraint(new_source->buffer); + tir::Buffer new_buffer; + if (opt_se_scope.defined()) { + new_buffer = RewriteBuffer(match_buffer_region_node->buffer, opt_se_scope.value()); + } else { + new_buffer = match_buffer_region_node->buffer; + } + if (!new_buffer.same_as(match_buffer_region_node->buffer) || + !new_source.same_as(match_buffer_region_node->source)) { + return MatchBufferRegion(new_buffer, new_source); + } + return GetRef(match_buffer_region_node); + } + + template + Array VisitItems(Array items) { + bool any_change = false; + Array new_items; + new_items.reserve(items.size()); + for (const auto& item : items) { + T new_item = VisitItem(item.get()); + if (!new_item.same_as(item)) { + any_change = true; + } + new_items.push_back(new_item); + } + return any_change ? new_items : items; + } + + Stmt VisitStmt_(const BlockNode* block_node) final { + Block new_block = Downcast(StmtExprMutator::VisitStmt_(block_node)); + Array new_reads = VisitItems(new_block->reads); + Array new_writes = VisitItems(new_block->writes); + // TODO(mbs): What memory scope should we assign to the new buffers? 
+ Array new_match_buffers = VisitItems(new_block->match_buffers); + if (!new_reads.same_as(new_block->reads) || !new_writes.same_as(new_block->writes) || + !new_match_buffers.same_as(new_block->match_buffers)) { + return Block(new_block->iter_vars, std::move(new_reads), std::move(new_writes), + new_block->name_hint, new_block->body, new_block->init, new_block->alloc_buffers, + std::move(new_match_buffers), new_block->annotations, new_block->span); + } + return new_block; + } + + // BlockRealizeNode default ok + + /*! Applies \p var_subst_ substitution to \p var_node. */ + Var Subst(const VarNode* var_node) const { + auto itr = var_subst_.find(var_node); + return itr == var_subst_.end() ? GetRef(var_node) : itr->second; + } + + /*! Applies \p buffer_subst_ substitution to \p buffer. */ + Buffer Subst(const BufferNode* buffer_node) const { + auto itr = buffer_subst_.find(buffer_node); + return itr == buffer_subst_.end() ? GetRef(buffer_node) : itr->second; + } + + /*! + * \brief Rewrites \p buffer so as to follow the constraints in \p se_scope + * (currently just memory scope). + * + * Updates both the var_subst_ and buffer_subst_ to capture the rewrite, but + * also returns the new buffer. + */ + Buffer RewriteBuffer(const Buffer& buffer, const SEScope& se_scope) { + ICHECK(buffer->data->type_annotation.defined()); + const auto* pointer_type_node = buffer->data->type_annotation.as(); + ICHECK(pointer_type_node); + if (pointer_type_node->storage_scope == se_scope->memory_scope) { + // No change. 
+ return buffer; + } + PointerType new_pointer_type(pointer_type_node->element_type, se_scope->memory_scope); + Var new_data(buffer->data->name_hint, new_pointer_type, buffer->data->span); + var_subst_.emplace(buffer->data.get(), new_data); + Buffer new_buffer(new_data, buffer->dtype, buffer->shape, buffer->strides, buffer->elem_offset, + buffer->name, buffer->data_alignment, buffer->offset_factor, + buffer->buffer_type, buffer->span); + buffer_subst_.emplace(buffer.get(), new_buffer); + return new_buffer; + } + + /*! + * \brief Returns the SEScope capturing any memory scope in \p buffer. Returns nullptr if + * buffer's data var does not have a type annotation of \p PointerType. Returns the fully + * unconstrained \p SEScope if no memory scope is given. + */ + static Optional GetBufferConstraint(const tir::Buffer& buffer) { + const auto* pointer_type_node = PointerInBuffer(buffer); + return pointer_type_node == nullptr ? Optional() + : SEScope::ForMemoryScope(pointer_type_node->storage_scope); + } + + /*! + * \brief Maps each \p Buffer::data \p Var to its constrained equivalent. + */ + std::unordered_map var_subst_; + + /*! + * \brief Maps each \p Buffer to its constrained equivalent. + */ + std::unordered_map buffer_subst_; +}; + +} // namespace + +Array GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func, + const FuncType& relay_func_type) { + // Build the implied domain (in terms of the function's Relay type) implied by any memory scope + // constraints in the function's buffers, for both arguments and results. + Array se_scopes; + se_scopes.reserve(relay_func_type->arg_types.size() + 1); + + // For each Relay function parameter... + size_t current_primfunc_param_index = 0; + for (const auto& param_type : relay_func_type->arg_types) { + SEScope param_se_scope = + ConsistentParamConstraint(prim_func, param_type, &current_primfunc_param_index); + se_scopes.push_back(param_se_scope); + } + + // For the Relay function result... 
+ const Type& ret_type = relay_func_type->ret_type; + SEScope ret_se_scope = + ConsistentParamConstraint(prim_func, ret_type, &current_primfunc_param_index); + se_scopes.push_back(ret_se_scope); + + // Make sure all parameters of the prim_func have been accounted for. + CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index); + + return se_scopes; +} + +TVM_REGISTER_GLOBAL("tir.analysis.GetPrimFuncArgAndResultMemoryConstraints") + .set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type) { + Array memory_scopes; + memory_scopes.reserve(relay_func_type->arg_types.size() + 1); + for (const auto& se_scope : GetPrimFuncArgAndResultConstraints(prim_func, relay_func_type)) { + memory_scopes.push_back(se_scope->memory_scope); + } + return memory_scopes; + }); + +PrimFunc ApplyPrimFuncArgAndResultConstraints(const PrimFunc& prim_func, + const FuncType& relay_func_type, + const Array& arg_and_result_se_scopes) { + return ApplyDeviceConstraintsMutator().Rewrite(prim_func, relay_func_type, + arg_and_result_se_scopes); +} + +TVM_REGISTER_GLOBAL("tir.analysis.ApplyPrimFuncArgAndResultMemoryConstraints") + .set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type, + const Array& arg_and_result_memory_scopes) { + Array se_scopes; + se_scopes.reserve(arg_and_result_memory_scopes.size()); + for (const auto& memory_scope : arg_and_result_memory_scopes) { + se_scopes.push_back(SEScope::ForMemoryScope(memory_scope)); + } + return ApplyPrimFuncArgAndResultConstraints(prim_func, relay_func_type, se_scopes); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/analysis/device_constraint_utils.h b/src/tir/analysis/device_constraint_utils.h new file mode 100644 index 000000000000..94cb40e5a83e --- /dev/null +++ b/src/tir/analysis/device_constraint_utils.h @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/device_constraint_utils.h + * \brief Utilities for extracting and applying device-related constraints to \p PrimFunc + * parameters. + * + * These utilities are used by the \p PlanDevices pass to extract memory (aka 'storage') scope + * information from \p PrimFuncs and convert them back into \p SEScope form w.r.t. the original + * Relay type of the \p PrimFunc (ie before flattening of tuple arguments/results and conversion + * to destination-passing style aka DPS). + * + * A utility is also supplied to go the other way: impose memory scopes on \p PrimFunc parameters. + * However that's still in EXPERIMENTAL form. + * + * We may extend these utilities to also gather/apply layout information should we add that to + * \p SEScope. + */ + +#ifndef TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_ +#define TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_ + +#include +#include + +namespace tvm { +namespace tir { + +/* + * A Relay Function with type: + * \code + * fn((Tensor[...], Tensor[...]), Tensor[...]) -> (Tensor[...], Tensor[...]) + * ^ ^ ^ ^ ^ + * a b c d e + * \endcode + * will be represented by a TIR PrimFunc in flattened and DPS form with at least 5 arguments a..e. 
+ * Each such PrimFunc argument will have a type annotation for a PointerType to the underlying + * tensor's buffer. The PrimFunc may have additional non-pointer arguments, for example to represent + * device contexts or other non-tensor arguments, and those should be ignored here since they have + * no counterpart in the Relay Function. + */ + +/*! + * \brief Returns the \p SEScopes capturing the memory (aka storage) scope constraints for all the + * arguments and result of \p prim_func. However the result will be w.r.t. the \p prim_func's + * representation as a Relay \p Function of \p relay_func_type_ before lowering and conversion to + * DPS. + */ +Array GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func, + const FuncType& relay_func_type); + +/* + * \brief Returns \p prim_func written to capture the memory (aka storage) scope constraints + * for each of the \p prim_func's parameters given by \p arg_and_result_se_scopes. However, + * \p arg_and_result_se_scopes should be w.r.t. the \p prim_func's representation as a Relay + * \p Function of \p relay_func_type before lowering and conversion to DPS. + * + * CAUTION: This is experimental. The resulting \p PrimFunc may not have fully accounted for all + * new memory scopes. + */ +PrimFunc ApplyPrimFuncArgAndResultConstraints(const PrimFunc& prim_func, + const FuncType& relay_func_type, + const Array& arg_and_result_se_scopes); + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_ diff --git a/tests/python/tir/analysis/test_device_constraint_utils.py b/tests/python/tir/analysis/test_device_constraint_utils.py new file mode 100644 index 000000000000..65cb4e398294 --- /dev/null +++ b/tests/python/tir/analysis/test_device_constraint_utils.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test retrieving and applying memory scope constraints to PrimFuncs""" +import tvm +from tvm import tir +from tvm import relay +from tvm.script import tir as T + + +@T.prim_func +def gem(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128], scope="scopeA") + B = T.match_buffer(b, [128, 128], scope="scopeA") + C = T.match_buffer(c, [128, 128], scope="scopeB") + D = T.match_buffer(d, [128, 128], scope="scopeC") + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] + + +gem_ty = relay.FuncType( + [ + relay.TupleType( + [ + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + ] + ), + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), +) + + +def test_get_prim_func_arg_and_result_constraints(): + scopes = tir.analysis.get_prim_func_arg_and_result_memory_constraints(gem, gem_ty) + assert [x for x in scopes] == ["scopeA", "scopeB", "scopeC"] + + +def test_apply_prim_func_arg_and_result_memory_constraints(): + rewritten = tir.analysis.apply_prim_func_arg_and_result_memory_constraints( + gem, gem_ty, ["scopeX", "scopeY", "scopeZ"] + ) 
+ scopes = tir.analysis.get_prim_func_arg_and_result_memory_constraints(rewritten, gem_ty) + assert [x for x in scopes] == ["scopeX", "scopeY", "scopeZ"] + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 4468dd8278da24025e2cc3433393f50104d385f8 Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Thu, 9 Dec 2021 12:11:25 -0800 Subject: [PATCH 2/2] [checkpoint] Junru's comments. --- include/tvm/tir/analysis.h | 3 --- python/tvm/tir/analysis/analysis.py | 6 +++--- src/tir/analysis/device_constraint_utils.cc | 13 ++---------- src/tir/analysis/device_constraint_utils.h | 23 ++++++++++++++++----- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index fa63b2617f48..51bdb18d2217 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -26,14 +26,12 @@ #include #include -#include #include #include #include #include #include -#include namespace tvm { namespace tir { @@ -244,5 +242,4 @@ TVM_DLL Pass VerifyGPUCode(Map constraints); } // namespace transform } // namespace tir } // namespace tvm - #endif // TVM_TIR_ANALYSIS_H_ diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index c74837aa820a..c2338dd9b611 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -16,7 +16,7 @@ # under the License. """Wrapping existing analysis utils.""" # pylint: disable=invalid-name -from typing import Dict, List, AnyStr +from typing import Dict, List from tvm import Object from tvm.tir.stmt import Block, BufferRegion @@ -205,7 +205,7 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]: def get_prim_func_arg_and_result_memory_constraints( func: PrimFunc, relay_func_type: Object -) -> List[AnyStr]: +) -> List[str]: """Returns the memory (aka storage) scope constraints for all the arguments and result of func. However the result will be w.r.t. 
the func's representation as a Relay Function of relay_func_type before lowering and conversion to DPS. @@ -232,7 +232,7 @@ def get_prim_func_arg_and_result_memory_constraints( def apply_prim_func_arg_and_result_memory_constraints( - func: PrimFunc, relay_func_type: Object, arg_and_result_memory_scopes: List[AnyStr] + func: PrimFunc, relay_func_type: Object, arg_and_result_memory_scopes: List[str] ) -> PrimFunc: """Returns func written to capture the memory (aka storage) scope constraints for each of the func's parameters given by arg_and_result_memory_scopes. However, diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index a2b2e046ab10..8412cb8b8923 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -373,17 +373,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { template Array VisitItems(Array items) { - bool any_change = false; - Array new_items; - new_items.reserve(items.size()); - for (const auto& item : items) { - T new_item = VisitItem(item.get()); - if (!new_item.same_as(item)) { - any_change = true; - } - new_items.push_back(new_item); - } - return any_change ? new_items : items; + items.MutateByApply([this](const T& item) { return VisitItem(item.get()); }); // copy-on-write + return items; } Stmt VisitStmt_(const BlockNode* block_node) final { diff --git a/src/tir/analysis/device_constraint_utils.h b/src/tir/analysis/device_constraint_utils.h index 94cb40e5a83e..be0f199f5226 100644 --- a/src/tir/analysis/device_constraint_utils.h +++ b/src/tir/analysis/device_constraint_utils.h @@ -43,7 +43,7 @@ namespace tvm { namespace tir { -/* +/*! * A Relay Function with type: * \code * fn((Tensor[...], Tensor[...]), Tensor[...]) -> (Tensor[...], Tensor[...]) @@ -51,10 +51,23 @@ namespace tir { * a b c d e * \endcode * will be represented by a TIR PrimFunc in flattened and DPS form with at least 5 argument a..e. 
- * Each such PrimFunc argument will have a type annotation for a PointerType to the underlying
- * tensor's buffer. The PrimFunc may have additional non-pointer arguments, for example to represent
- * device contexts or other non-tensor arguments, and those should be ignored here since they have
- * no counterpart in the Relay Function.
+ * \code
+ * primfn(a: handle, b: handle, c: handle, d: handle, e: handle) {
+ *   buffers = { ... }
+ *   buffer_map = { ... }
+ *   ...
+ * }
+ * \endcode
+ *
+ * Each such PrimFunc argument will be mapped to a \p Buffer whose underlying \p data \p Var
+ * has a \p PointerType.
+ *
+ * The PrimFunc may have additional non-pointer arguments, e.g. for:
+ *  - scalar inputs and tensor dimensions
+ *  - device contexts
+ * Those should be ignored here since they have no counterpart in the Relay Function.
+ *
+ * We'll need helpers to map on-the-fly between the Relay and TIR view of functions.
  */

 /*!