diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 5376d99ee15b..efe30e5cbb50 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -572,6 +572,16 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2); */ TVM_DLL Pass DeadCodeElimination(Array entry_functions); +/*! + * \brief Pass that changes calls to operators that can be done in-place + * (generally, these are elementwise operations) in dataflow blocks into in-place implementations. + * Supported operators will be replaced by calls to `call_tir_inplace` that invoke in-place + * PrimFunc implementations of those operators (which are based on the legalizations of those + * operators). + * \return The pass. + */ +TVM_DLL Pass DataflowUseInplaceCalls(); + /*! * \brief Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 * only, and will automatically cast fp32 to fp16 for certain ops. diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 5bc0d6c56eb4..23cfaf293560 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -63,7 +63,13 @@ from .exec_builder import ExecBuilder # Operator -from .op.base import call_tir, call_pure_packed, call_dps_packed, call_tir_with_grad +from .op.base import ( + call_tir, + call_tir_inplace, + call_pure_packed, + call_dps_packed, + call_tir_with_grad, +) # BlockBuilder from .block_builder import BlockBuilder diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index ccae38a138a3..42dbd37d2931 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -17,14 +17,18 @@ # pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ """Relax transformation passes for testing""" +import logging +import os +from typing import Dict, List, Set, Tuple import tvm from tvm import ir, relax from tvm.ir import transform from tvm.ir.module import IRModule from tvm.ir.transform import PassContext from tvm.relax import PyExprMutator -from tvm.relax.expr import Call +from tvm.relax.expr import Call, DataflowBlock, Var from tvm.relay.backend.te_compiler import select_implementation +from tvm.runtime.object import Object from tvm.target import Target @@ -128,3 +132,95 @@ def transform(self): def ApplyEmptyCppMutator() -> tvm.ir.transform.Pass: packed_func = tvm.get_global_func("relax.testing.transform.ApplyEmptyCppMutator") return packed_func() + + +def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int]]: + """ + Inner function for the dataflow inplace transformation exposed for testing. + """ + if "PYTEST_CURRENT_TEST" not in os.environ: + logging.warning("The function dataflow_liveness_analysis is exposed for testing only.") + + live_ranges = tvm.get_global_func("relax.testing.transform.DataflowLivenessAnalysis")( + block + ) # type: ignore + ret = {} + for var, live_range in live_ranges.items(): + ret[var] = tuple(live_range) + return ret # type: ignore + + +def dataflow_alias_analysis( + block: DataflowBlock, inputs: List[Var] +) -> Tuple[Dict[Var, Set[int]], Dict[int, List[Set[int]]]]: + """ + Inner function for the dataflow inplace transformation exposed for testing. 
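+
+    Returns a map from each variable to the set of memory-location indices it may
+    alias (one index per distinct allocation, with -1 denoting an unknown or
+    external value), plus a map from tuple memory locations to the alias sets of
+    their elements.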
+ """ + if "PYTEST_CURRENT_TEST" not in os.environ: + logging.warning("The function dataflow_alias_analysis is exposed for testing only.") + + alias_sets, tuple_map = tvm.get_global_func("relax.testing.transform.DataflowAliasAnalysis")( + block, + inputs, + ) # type: ignore + res_alias_sets = {} + res_tuple_map = {} + for var, alias_set in alias_sets.items(): + res_alias_sets[var] = set(alias_set) + for idx, elem_alias_sets in tuple_map.items(): + res_tuple_map[idx] = [set(alias_set) for alias_set in elem_alias_sets] + return res_alias_sets, res_tuple_map # type: ignore + + +@tvm._ffi.register_object("relax.transform.InplaceOpportunity") +class InplaceOpportunity(Object): + """ + Represents an opportunity to make a binding in-place. Exposed only for testing; + the constructor is not exposed. + + Parameters: + ----------- + binding_idx: int + Index of the binding within its block + + arg_idxs: List[int] + Indices of arguments that are eligible to be used as in-place targets. + """ + + def __init__(self, _binding_idx, _arg_idxs): + raise NotImplementedError("Constructor for InplaceOpportunity not exposed!") + + +def dataflow_inplace_analysis( + block: DataflowBlock, inputs: List[Var], mod: IRModule +) -> Tuple[List[Tuple[int, Set[int]]], List[Tuple[int, Set[int]]]]: + """ + Inner function for the dataflow inplace transformation exposed for testing. + """ + if "PYTEST_CURRENT_TEST" not in os.environ: + logging.warning("The function dataflow_inplace_analysis is exposed for testing only.") + index_lists = tvm.get_global_func("relax.testing.transform.DataflowInplaceAnalysis")( + block, inputs, mod + ) # type: ignore + + def convert(opp_list): + return list(map(lambda opp: (int(opp.binding_idx), set(map(int, opp.arg_idxs))), opp_list)) + + return (convert(index_lists[0]), convert(index_lists[1])) # type: ignore + + +def dataflow_single_inplace_call( + mod: IRModule, call: Call, inplace_indices: List[int] +) -> Tuple[Call, IRModule]: + """ + Inner function for the dataflow inplace transformation exposed for testing. + """ + if "PYTEST_CURRENT_TEST" not in os.environ: + logging.warning("The function dataflow_single_inplace_call is exposed for testing only.") + + ret = tvm.get_global_func("relax.testing.transform.SingleInplaceCall")( + mod, + call, + inplace_indices, + ) # type: ignore + return (ret[0], ret[1]) # type: ignore diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 19316c76b83d..353ee88b6898 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -30,6 +30,7 @@ ConvertLayout, ConvertToDataflow, DataflowBlockPass, + DataflowUseInplaceCalls, DeadCodeElimination, DecomposeOpsForInference, DecomposeOpsForTraining, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 9589f661d79e..99fdc67c29ce 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -252,6 +252,22 @@ def RemovePurityChecking() -> tvm.ir.transform.Pass: return _ffi_api.RemovePurityChecking() # type: ignore +def DataflowUseInplaceCalls() -> tvm.ir.transform.Pass: + """ + Pass that changes calls to operators that can be done in-place + (generally, these are elementwise operations) into in-place implementations. + Supported operators will be replaced by calls to `call_tir_inplace` that invoke + in-place PrimFunc implementations of those operators (which are based on the legalizations of + those operators). 
+ + Returns + ------- + ret: tvm.ir.transform.Pass + The pass + """ + return _ffi_api.DataflowUseInplaceCalls() + + def LambdaLift() -> tvm.ir.transform.Pass: """A pass that lifts local functions into global. diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc new file mode 100644 index 000000000000..755c5dbab433 --- /dev/null +++ b/src/relax/transform/dataflow_inplace.cc @@ -0,0 +1,1040 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/dataflow_inplace.cc + * \brief Pass that converts eligible operator calls in dataflow blocks + * into in-place versions. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +namespace tvm { +namespace relax { + +// Perform liveness analysis on a dataflow block, returning a map of vars to +// pairs of indices (the liveness interval, from the starting index to the end index). +// A starting index of -1 means the var is defined before the block starts and an end index +// of block->bindings.size() (one past the last index) means it is live after the block ends. +std::unordered_map, ObjectPtrHash, ObjectPtrEqual> AnalyzeLiveness( + const DataflowBlock& block) { + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> ret; + for (int i = block->bindings.size() - 1; i >= 0; i--) { + Binding b = block->bindings[i]; + Var defined_var = b->var; + Expr value = GetBoundValue(b); + Array used_vars; + // for a function literal, we consider only the free vars + // (those captured from the outer scope) + if (value.as()) { + used_vars = FreeVars(value); + } else if (value.as()) { + // Special case: we do not consider a tuple index to be a "use." + // This is a bit of a hack but allows us to do operations that + // create tuples to be done in-place (otherwise, any index of the tuple + // would be considered a use and so the tuple would be live later). + // Hence we keep the array empty. 
+ } else { + used_vars = AllVars(value); + } + + for (auto var : used_vars) { + int range_end = i; + // if the var is not a dataflow var, then it is live + // after the block (we are not checking later blocks) + if (!var.as()) { + range_end = block->bindings.size(); + } + if (!ret.count(var)) { + ret[var] = {-1, range_end}; + } + } + + if (!ret.count(defined_var)) { + // if it's an output, then it lives past the end of the block + if (!defined_var.as()) { + ret[defined_var] = {i, block->bindings.size()}; + } else { + // otherwise, it's live only here + ret[defined_var] = {i, i}; + } + } else { + // this means the var is used later but we encountered its definition now + auto last_range = ret[defined_var]; + CHECK_EQ(last_range.first, -1); + std::pair new_range = {i, last_range.second}; + ret[defined_var] = new_range; + } + } + return ret; +} + +class AliasAnalyzer { + public: + AliasAnalyzer() : alias_map_(), tuple_map_(), mem_idx_(0) {} + + // The analysis returns a map of vars to memory locations that it *could* map to + // (any unique allocation = one memory location), plus a map of memory locations + // that correspond to tuples (this maps to sets of memory locations for each tuple element). + // Note: inputs are values that should be assumed not to be aliased and are therefore + // (in the case of in-place ops) safe to overwrite. This may not be true of function args. + std::pair, ObjectPtrHash, ObjectPtrEqual>, + std::unordered_map>>> + Analyze(const DataflowBlock& block, const Array& inputs) { + for (auto input : inputs) { + int curr_idx = get_fresh_idx(); + alias_map_[input] = {curr_idx}; + if (auto* tup_info = GetStructInfoAs(input)) { + InsertFreshTuple(curr_idx, tup_info); + } + } + + for (const Binding& binding : block->bindings) { + Var current_var = binding->var; + Expr value = GetBoundValue(binding); + alias_map_[current_var] = GetAliasSet(value, current_var); + } + + return {alias_map_, tuple_map_}; + } + + private: + int get_fresh_idx() { + int ret = mem_idx_; + mem_idx_++; + return ret; + } + + // Fresh tuple = each element is assumed to be a unique allocation + void InsertFreshTuple(int tup_idx, const TupleStructInfoNode* tup_info) { + std::vector> tuple_set; + for (int i = 0; i < static_cast(tup_info->fields.size()); i++) { + int curr_field = get_fresh_idx(); + tuple_set.push_back({curr_field}); + if (auto* nested_tup_info = tup_info->fields[i].as()) { + InsertFreshTuple(curr_field, nested_tup_info); + } + } + tuple_map_[tup_idx] = tuple_set; + } + + // given a tuple index, add the given memory location indices to each component's + // alias set + void UpdateTupleComponents(int tup_idx, const std::unordered_set& insert_idxs) { + if (tuple_map_.count(tup_idx)) { + auto tuple_comps = tuple_map_[tup_idx]; + for (size_t i = 0; i < tuple_comps.size(); i++) { + auto comp_set = tuple_comps[i]; + + // if a member is a tuple, update its components as well + for (int member : comp_set) { + if (tuple_map_.count(member)) { + UpdateTupleComponents(member, insert_idxs); + } + } + + // update after iterating to avoid iterating over the inserted elements + tuple_map_[tup_idx][i].insert(insert_idxs.begin(), insert_idxs.end()); + } + } + } + + // capture the given index and also its tuple components (including recursively) + // if they exist + void AddCapturedIndices(std::unordered_set* captured_set, int idx) { + captured_set->insert(idx); + if (tuple_map_.count(idx)) { + for (auto comp_set : tuple_map_[idx]) { + for (auto tup_comp_idx : comp_set) { + AddCapturedIndices(captured_set, 
tup_comp_idx); + } + } + } + } + + // Conservative extremely pessimistic assumption: + // assume that the result of a non-op call can be aliased to any argument + // or that it could be a newly allocated value. + // For tuples, assume all members are aliased. Yeah, it's bad. + // (Skip first arg is for handling call_pure_packed, where the first arg is an ExternFunc that we + // should ignore) + std::unordered_set HandleMysteryCall(const CallNode* call_node, const Var& bound_var, + bool skip_first_arg = false) { + // the result may or may not be newly allocated + std::unordered_set ret; + int res_idx = get_fresh_idx(); + // the result may be a tuple + if (auto* tup_info_node = GetStructInfoAs(bound_var)) { + InsertFreshTuple(res_idx, tup_info_node); + } + AddCapturedIndices(&ret, res_idx); + + for (size_t i = (skip_first_arg) ? 1 : 0; i < call_node->args.size(); i++) { + auto arg = call_node->args[i]; + auto arg_alias_set = GetAliasSet(arg, bound_var); + for (int alias_idx : arg_alias_set) { + AddCapturedIndices(&ret, alias_idx); + } + } + // if the result is a tuple, the components can also potentially be aliased to any arg + // or, in fact, to each other + UpdateTupleComponents(res_idx, ret); + return ret; + } + + // given the expression value, return the set of memory locations corresponding to it + // (the var the expression is being bound to is needed for struct info) + std::unordered_set GetAliasSet(const Expr& value, const Var& bound_var) { + std::unordered_set ret; + + // cases for value: + // constant: it's a fresh index + // var: look up in alias map (-1 if not present) + // op call: assume it's fresh (may need to make list of exceptions) + // tuple: fresh entry in tuple index, recurse to determine indices for values + // function/packed call: chaos reigns, alias with any other argument + // (if tuple is passed, assume also aliased with all members of the tuple) + // tuple index: -1 if tuple is not in tuple map, otherwise look up corresponding entry + // function constant: give them a fresh index (TODO: we can handle in more detail if this is a + // case we need to support) prim value: fresh index if node: should not happen inside dataflow + // block + if (value.as() || value.as() || value.as()) { + // TODO(@slyubomirsky): We will probably want special handling for closures + ret.insert(get_fresh_idx()); + } else if (auto* target_var_node = value.as()) { + auto target_var = GetRef(target_var_node); + if (alias_map_.count(target_var)) { + ret.insert(alias_map_[target_var].begin(), alias_map_[target_var].end()); + } else { + ret.insert(-1); + } + } else if (auto* target_tuple = value.as()) { + // fresh idx but we update the tuple map + int tup_idx = get_fresh_idx(); + ret.insert(tup_idx); + std::vector> new_tuple_map; + for (auto field : target_tuple->fields) { + new_tuple_map.push_back(GetAliasSet(field, bound_var)); + } + tuple_map_[tup_idx] = new_tuple_map; + } else if (auto* target_tgi = value.as()) { + std::unordered_set tuple_set = GetAliasSet(target_tgi->tuple, bound_var); + // if -1 is a member of the tuple set, then we have to assume the result is -1 + if (tuple_set.count(-1)) { + ret.insert(-1); + } else { + // otherwise, consider all members that are tuples of appropriate size and index into them + // (this is safe because the type system will ensure we're not indexing into a tuple + // of the wrong size) + for (int member : tuple_set) { + if (tuple_map_.count(member) && + static_cast(tuple_map_[member].size()) > target_tgi->index) { + auto member_set = 
tuple_map_[member][target_tgi->index]; + ret.insert(member_set.begin(), member_set.end()); + } + } + } + } else if (auto* call_node = value.as()) { + if (auto* op_node = call_node->op.as()) { + // call_pure_packed: treat as non-op call + if (op_node->name == "relax.call_pure_packed") { + return HandleMysteryCall(call_node, bound_var, true); + } else if (op_node->name == "relax.call_tir") { + // call_tir: can potentially return a tuple + if (auto* tuple_struct_info = call_node->sinfo_args[0].as()) { + int tup_idx = get_fresh_idx(); + ret.insert(tup_idx); + InsertFreshTuple(tup_idx, tuple_struct_info); + } else { + ret.insert(get_fresh_idx()); + } + } else { + // We are assuming most op calls return fresh values. + // We may have to track more exceptions + + // If the returned value is a tuple, we'll assume it's a fresh tuple + // (there may be exceptions to this too) + if (auto* tup_info = GetStructInfoAs(bound_var)) { + int tup_idx = get_fresh_idx(); + ret.insert(tup_idx); + InsertFreshTuple(tup_idx, tup_info); + return ret; + } + ret.insert(get_fresh_idx()); + } + } else { + // assume any non-op call can be extremely dangerous and do anything + return HandleMysteryCall(call_node, bound_var); + } + } + + return ret; + } + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> alias_map_; + std::unordered_map>> tuple_map_; + int mem_idx_; +}; + +// given a shape, return the number of elements corresponding to it (product of elements) +PrimExpr NumElements(const ShapeExpr& shape) { + PrimExpr ret = IntImm(DataType::Int(64), 1); + for (auto dim : shape->values) { + ret *= dim; + } + return ret; +} + +// Given the struct info of the result, return any struct info nested in it +// that is eleigible to be used for in-place computations (tensors are eligible +// only if all their dimensions are integer constants, tuples are eligible if +// all members are eligible though we can consider only individual members separately) +std::unordered_set GatherCandidateSinfo( + const StructInfo& result_sinfo) { + if (auto* tensor_info = result_sinfo.as()) { + // don't consider void dtype (don't know the size at compile time) + if (tensor_info->dtype.is_void()) { + return {}; + } + // don't consider cases where we don't know the shape at compile time + // (we will use the analyzer to do best-effort analysis where there are vars) + if (tensor_info->shape.as()) { + return {GetRef(tensor_info)}; + } else { + return {}; + } + } else if (auto* tuple_info = result_sinfo.as()) { + // we can see if the whole tuple matches or go for any of the components + std::unordered_set ret; + for (auto field : tuple_info->fields) { + auto field_candidates = GatherCandidateSinfo(field); + ret.insert(field_candidates.begin(), field_candidates.end()); + } + // at least one field should be eligible to be done in-place + if (!ret.empty()) { + ret.insert(GetRef(tuple_info)); + } + return ret; + } else { + // don't consider any other types + return {}; + } +} + +// Given the two struct info, return a pair of bools where the first element is true if +// the two struct info have the same number of elements and dtype and the second element is true +// if the shapes match _exactly_. Performs this check recursively and ensures the +// stated condition is true for all tensor members of the struct info (return false +// if a single pair of corresponding tensors does not meet the condition). 
+std::pair SizeMatches(const StructInfo& target_info, const StructInfo& arg_info, + const BlockBuilder& ctx) { + if (target_info.as() && arg_info.as()) { + auto target_tensor = Downcast(target_info); + auto arg_tensor = Downcast(arg_info); + if (target_tensor->shape.defined() && target_tensor->shape.as() && + arg_tensor->shape.defined() && arg_tensor->shape.as()) { + if (target_tensor->dtype != arg_tensor->dtype) { + return {false, false}; + } + auto target_shape = Downcast(target_tensor->shape); + auto arg_shape = Downcast(arg_tensor->shape); + PrimExpr target_size = NumElements(target_shape); + PrimExpr arg_size = NumElements(arg_shape); + if (!ctx->GetAnalyzer()->CanProve(arg_size >= target_size)) { + return {false, false}; + } + // exact match: number of dims and each dim matches + if (target_shape->values.size() == arg_shape->values.size()) { + for (size_t i = 0; i < target_shape->values.size(); i++) { + if (!ctx->GetAnalyzer()->CanProveEqual(target_shape->values[i], arg_shape->values[i])) { + return {true, false}; + } + } + return {true, true}; + } + return {true, false}; + } else { + return {false, false}; + } + } else if (target_info.as() && arg_info.as()) { + auto target_tup = Downcast(target_info); + auto arg_tup = Downcast(arg_info); + if (target_tup->fields.size() != arg_tup->fields.size()) { + return {false, false}; + } + bool all_exact = true; + for (size_t i = 0; i < target_tup->fields.size(); i++) { + // if members aren't either tuples or tensors, simply skip them, + // since they don't matter for in-place computations + if (!(target_tup->fields[i].as() || + target_tup->fields[i].as()) && + !(arg_tup->fields[i].as() || + arg_tup->fields[i].as())) { + continue; + } + auto [field_size_match, field_exact_match] = + SizeMatches(target_tup->fields[i], arg_tup->fields[i], ctx); + if (!field_size_match) { + return {false, false}; + } + all_exact = all_exact && field_exact_match; + } + return {true, all_exact}; + } else { + return {false, false}; + } +} + +// Given an alias index, check if it's a tuple and gather the sets of aliases for the tuple +// members if so (apply recursively if any of those members are tuples). +// Return false if the alias set contains -1, meaning a reference to an unknown or +// possibly dangerous value (no checking we can do for that). 
+bool GatherSetsToCheckForLiveness( + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + alias_sets, + const std::unordered_map>>& tuple_map, + std::vector>* sets_to_check, int alias_idx) { + if (tuple_map.count(alias_idx)) { + for (auto member_set : tuple_map.at(alias_idx)) { + // contains -1 -> unknown and dangerous, we can short-circuit + if (member_set.count(-1)) { + return false; + } + sets_to_check->push_back(member_set); + + // if a member can be a tuple, check it recursively + for (int member : member_set) { + if (tuple_map.count(member)) { + if (!GatherSetsToCheckForLiveness(alias_sets, tuple_map, sets_to_check, member)) { + return false; + } + } + } + } + } + return true; +} + +// Check that the target is not live past the index and that no alias of it is live past the +// binding index (if the target is a tuple, check the conditions recursively for the members) +bool InplaceConditionsMet( + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& live_ranges, + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + alias_sets, + const std::unordered_map>>& tuple_map, + const std::unordered_set& currently_live, + const Expr& target, int binding_idx) { + if (auto* var_node = target.as()) { + auto current_var = GetRef(var_node); + // if the var is live past this point, we can't use it for in-place computations anyway + if (live_ranges.count(current_var)) { + auto live_range = live_ranges.at(current_var); + if (live_range.second > binding_idx) { + return false; + } + } + + // no entry for the current var -> it must be something external and we have to assume the worst + if (!alias_sets.count(current_var)) { + return false; + } + auto alias_set = alias_sets.at(current_var); + // -1 -> an external value and we must assume the worst + if (alias_set.count(-1)) { + return false; + } + std::vector> sets_to_check = {alias_set}; + std::unordered_set indices_checked; + // If a possible alias is a tuple, we will also check for aliases of the members + // (possibly recursively) + for (int alias_idx : alias_set) { + if (!GatherSetsToCheckForLiveness(alias_sets, tuple_map, &sets_to_check, alias_idx)) { + return false; + } + } + + for (Var other_var : currently_live) { + if (other_var.same_as(target)) { + continue; + } + // not represented = spooky unknown value that should be modeled by -1 + if (!alias_sets.count(other_var) || !live_ranges.count(other_var)) { + continue; + } + // var is not live past this point => don't need to worry + if (live_ranges.at(other_var).second <= binding_idx) { + continue; + } + auto other_alias_set = alias_sets.at(other_var); + for (int alias_idx : other_alias_set) { + for (auto check_set : sets_to_check) { + if (check_set.count(alias_idx)) { + return false; + } + } + } + } + return true; + } else if (auto* tup_node = target.as()) { + for (auto field : tup_node->fields) { + if (!InplaceConditionsMet(live_ranges, alias_sets, tuple_map, currently_live, field, + binding_idx)) { + return false; + } + } + return true; + } else { + return true; + } +} + +// this is obviously not a complete list +static std::unordered_set SUPPORTED_OPS = {"relax.add", "relax.subtract", + "relax.multiply", "relax.divide", + "relax.nn.silu", "relax.nn.relu"}; +bool OpSupportsInplace(const Op& op) { return SUPPORTED_OPS.count(op->name); } + +/*! \brief Corresponds to a binding where at least one argument meets the conditions to be + * made in-place. 
Contains the binding index and indices of the applicable arguments + */ +class InplaceOpportunityNode : public Object { + public: + // need to use Array for the benefit of the FFI + Integer binding_idx; + Array arg_idxs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("binding_idx", &binding_idx); + v->Visit("arg_idxs", &arg_idxs); + } + + static constexpr const char* _type_key = "relax.transform.InplaceOpportunity"; + TVM_DECLARE_BASE_OBJECT_INFO(InplaceOpportunityNode, Object); +}; + +TVM_REGISTER_NODE_TYPE(InplaceOpportunityNode); + +class InplaceOpportunity : public ObjectRef { + public: + TVM_DLL InplaceOpportunity(const Integer& binding_idx, const Array& arg_idxs) { + auto node = make_object(); + node->binding_idx = binding_idx; + node->arg_idxs = arg_idxs; + data_ = std::move(node); + } + + TVM_DEFINE_OBJECT_REF_METHODS(InplaceOpportunity, ObjectRef, InplaceOpportunityNode); +}; + +// Check for in-place eligibility: +// 1. see if there's an arg big enough to hold the result +// 2. see if the arg is live past the call +// 3. see if the arg has an alias that's live past the call +// If the conditions are met, record the index of that binding. +// Returns two lists of lists: +// 1. A list of bindings where at least one argument meets the in-place conditions and the *size* +// matches the size of the result. +// 2. A list of bindings where at least one argument meets the in-place conditions +// and *exactly* matches the shape of the result. +// For both lists, each element is a list of ints of the following format: +// The first element is the index of the *binding* in the block. +// All remaining elements are the indices of *eligible arguments* in that call. +std::pair, std::vector> +FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, + const BlockBuilder& ctx) { + auto live_ranges = AnalyzeLiveness(block); + AliasAnalyzer analyzer; + auto alias_info = analyzer.Analyze(block, inputs); + auto alias_sets = alias_info.first; + auto tuple_map = alias_info.second; + + std::vector size_match_list; + std::vector exact_match_list; + + // sort the live ranges by starting index + std::vector live_order; + for (auto kv : live_ranges) { + live_order.push_back(kv.first); + } + std::sort(live_order.begin(), live_order.end(), + [&live_ranges](const Var& var1, const Var& var2) -> bool { + return live_ranges[var1].first < live_ranges[var2].first; + }); + + std::unordered_set currently_live; + int last_live = 0; + + for (size_t i = 0; i < block->bindings.size(); i++) { + // include all vars that are currently live + for (int j = last_live; j < static_cast(live_order.size()); j++) { + auto live_var = live_order[j]; + auto live_range = live_ranges[live_var]; + if (live_range.first > static_cast(i)) { + break; + } + currently_live.insert(live_var); + last_live++; + } + // remove vars whose range has come to an end + // (keep a separate set to avoid changing the set while iterating on it) + std::unordered_set remove; + for (auto var : currently_live) { + auto live_range = live_ranges[var]; + if (live_range.second < static_cast(i)) { + remove.insert(var); + } + } + for (auto var : remove) { + currently_live.erase(var); + } + + // if we reach a binding check the conditions + Binding b = block->bindings[i]; + Var defined_var = b->var; + Expr value = GetBoundValue(b); + + if (auto* call_node = value.as()) { + if (auto* op_node = call_node->op.as()) { + if (!OpSupportsInplace(GetRef(op_node))) { + continue; + } + + std::unordered_set candidates; + std::unordered_set 
exact_match_candidates; + + auto target_sinfo = GatherCandidateSinfo(GetStructInfo(defined_var)); + // can't be done in-place, ignore + if (target_sinfo.empty()) { + continue; + } + + // Check that at least one argument matches size with the result + for (size_t j = 0; j < call_node->args.size(); j++) { + auto arg = call_node->args[j]; + for (auto target : target_sinfo) { + auto [matches_size, matches_exactly] = SizeMatches(target, GetStructInfo(arg), ctx); + if (matches_size) { + candidates.insert(static_cast(j)); + if (matches_exactly) { + exact_match_candidates.insert(static_cast(j)); + } + } + } + } + if (candidates.empty()) { + continue; + } + + // Make sure at least one candidate is not live past this point and does not have an alias + // live past this point + std::unordered_set remove_candidates; + for (auto candidate : candidates) { + if (!InplaceConditionsMet(live_ranges, alias_sets, tuple_map, currently_live, + call_node->args[candidate], i)) { + remove_candidates.insert(candidate); + } + } + // (remove now to avoid modifying the list as we iterate on it) + for (auto candidate : remove_candidates) { + candidates.erase(candidate); + } + + // if we have a candidate, then this can be made in-place. Report the appropriate candidates + if (candidates.empty()) { + continue; + } + + // produce a list of candidates for this index + Array size_candidate_list; + for (auto candidate : candidates) { + size_candidate_list.push_back(Integer(candidate)); + } + size_match_list.push_back(InplaceOpportunity(Integer(i), size_candidate_list)); + + // also gather up the exact match candidates if there are any + Array exact_candidate_list; + for (auto candidate : candidates) { + if (!exact_match_candidates.count(candidate)) { + continue; + } + exact_candidate_list.push_back(Integer(candidate)); + } + if (exact_candidate_list.empty()) { + continue; + } + exact_match_list.push_back(InplaceOpportunity(Integer(i), exact_candidate_list)); + } + } + } + + return {size_match_list, exact_match_list}; +} + +// Replace buffers in a PrimFunc according to the mapping. 
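+// Used when creating in-place calls: the output buffers of a legalized PrimFunc
+// are redirected to the buffers of the arguments chosen as in-place targets.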
+tir::Stmt RemapBuffers(const tir::Stmt& stmt, const Map& buffer_map) { + class BufferMapper : public tir::StmtExprMutator { + public: + explicit BufferMapper(const Map& buffer_map) + : buffer_map_(buffer_map) {} + + tir::Stmt Remap(const tir::Stmt& stmt) { return VisitStmt(stmt); } + + PrimExpr VisitExpr_(const tir::BufferLoadNode* op) final { + auto node = Downcast(tir::StmtExprMutator::VisitExpr_(op)); + auto* node_cow = node.CopyOnWrite(); + node_cow->buffer = AttemptRemap(node->buffer); + return node; + } + + tir::Stmt VisitStmt_(const tir::BufferStoreNode* op) final { + auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + auto* node_cow = node.CopyOnWrite(); + node_cow->buffer = AttemptRemap(node->buffer); + return node; + } + + tir::Stmt VisitStmt_(const tir::BufferRealizeNode* op) final { + auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + auto* node_cow = node.CopyOnWrite(); + node_cow->buffer = AttemptRemap(node->buffer); + return node; + } + + tir::Stmt VisitStmt_(const tir::DeclBufferNode* op) final { + auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + auto* node_cow = node.CopyOnWrite(); + node_cow->buffer = AttemptRemap(node->buffer); + return node; + } + + tir::Stmt VisitStmt_(const tir::BlockNode* op) final { + auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + auto* node_cow = node.CopyOnWrite(); + // need the lambdas because class methods are not first-class (how ironic) + node_cow->alloc_buffers = + node->alloc_buffers.Map([this](const tir::Buffer& b) { return AttemptRemap(b); }); + node_cow->reads = + node->reads.Map([this](const tir::BufferRegion& br) { return VisitBufferRegion(br); }); + node_cow->writes = + node->writes.Map([this](const tir::BufferRegion& br) { return VisitBufferRegion(br); }); + node_cow->match_buffers = node->match_buffers.Map( + [this](const tir::MatchBufferRegion& mbr) { return VisitMatchBufferRegion(mbr); }); + return node; + } + + private: + tir::Buffer AttemptRemap(const tir::Buffer& buffer) { + if (buffer_map_.count(buffer)) { + return buffer_map_.at(buffer); + } + return buffer; + } + + tir::BufferRegion VisitBufferRegion(tir::BufferRegion region) { + auto* region_cow = region.CopyOnWrite(); + region_cow->buffer = AttemptRemap(region_cow->buffer); + return region; + } + + tir::MatchBufferRegion VisitMatchBufferRegion(tir::MatchBufferRegion region) { + auto* region_cow = region.CopyOnWrite(); + region_cow->buffer = AttemptRemap(region_cow->buffer); + return region; + } + + const Map& buffer_map_; + }; + + BufferMapper mapper(buffer_map); + auto ret = mapper.Remap(stmt); + return ret; +} + +class ModuleInplaceTransformer : public ExprMutator { + public: + explicit ModuleInplaceTransformer(const IRModule& mod) : mod_(mod) { + builder_ = BlockBuilder::Create(mod); + } + + IRModule Transform() { + // visit every Relax function in the module + for (auto kv : mod_->functions) { + if (auto* func_node = kv.second.as()) { + auto gv = kv.first; + auto func_params = func_node->params; + auto function = Downcast(VisitExpr(GetRef(func_node))); + builder_->UpdateFunction(gv, function); + } + } + + auto ret = builder_->GetContextIRModule(); + // clean up to avoid polluting the IRModule + for (auto gv : legalizers_added) { + ret->Remove(gv); + } + return ret; + } + + Expr VisitExpr_(const FunctionNode* op) override { + auto old_func_params = func_params; + func_params = op->params; + auto ret = ExprMutator::VisitExpr_(op); + func_params = old_func_params; + return ret; + } + + // the only case we will override: we 
will visit all binding blocks + // and replace any valid calls in them + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + auto block = GetRef(op); + auto old_idxs = inplace_idxs; + + // For now, only handle exact match cases. + // Note: Not passing any input values for now, as we can't make any assumptions + // about them. + auto matches_found = FindInplaceOpportunities(block, {}, builder_); + Map> new_idxs; + for (auto match : matches_found.second) { + new_idxs.Set(block->bindings[match->binding_idx.IntValue()], match->arg_idxs); + } + + inplace_idxs = new_idxs; + auto ret = ExprMutator::VisitBindingBlock_(op); + inplace_idxs = old_idxs; + return ret; + } + + Expr ReplaceBoundCall(const Binding& binding) { + // can just pick the first index arbitrarily (only using one output for now too) + // now replace the binding appropriately + auto arg_idxs = inplace_idxs.at(binding); + auto target = Downcast(GetBoundValue(binding)); + auto new_call = CreateInplaceCall(target, {arg_idxs[0]}); + return builder_->Normalize(new_call); + } + + void VisitBinding_(const VarBindingNode* binding) override { + auto binding_ref = GetRef(binding); + if (!inplace_idxs.count(binding_ref)) { + ExprMutator::VisitBinding_(binding); + return; + } + Expr new_value = ReplaceBoundCall(binding_ref); + builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span)); + } + + void VisitBinding_(const MatchCastNode* binding) override { + auto binding_ref = GetRef(binding); + if (!inplace_idxs.count(binding_ref)) { + ExprMutator::VisitBinding_(binding); + return; + } + Expr new_value = ReplaceBoundCall(binding_ref); + builder_->EmitNormalized( + MatchCast(binding->var, new_value, binding->struct_info, binding->span)); + } + + // Given the call and indices of arguments that could be done in-place, + // replace the call with a call to an in-place PrimFunc. + // (Made public for testing.) + Call CreateInplaceCall(const Call& call, const Array& inplace_indices) { + static const auto& legalize_map = Op::GetAttrMap("FLegalize"); + static const auto& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); + + auto op = Downcast(call->op); + auto legalized_call = Downcast(legalize_map[op](builder_, call)); + auto* legalized_call_cow = legalized_call.CopyOnWrite(); + + // The legalized call should be call_tir. We will replace it with call_tir_inplace + // and replace the called PrimFunc with an inplace version + auto legal_op = Downcast(legalized_call->args[0]); + legalizers_added.push_back(legal_op); + auto inline_legal_op_name = legal_op->name_hint + "_inplace"; + + auto mod = builder_->GetContextIRModule(); + auto legal_primfunc = Downcast(mod->Lookup(legal_op)); + auto* legal_primfunc_cow = legal_primfunc.CopyOnWrite(); + size_t num_outs = inplace_indices.size(); + size_t num_params = legal_primfunc->params.size(); + + // the replacement we must make: + // 1. For each output var, replace its corresponding buffers with the corresponding inplace + // index + // var's buffers + // 2. For each output var, replace its instances with the corresponding inplace index var + // 3. Do the same for the *buffer vars* corresponding to the output vars + // 4. 
Remove the output vars from the param list and buffer map + Map buffer_subst_map; + Map var_subst_map; + for (size_t i = 0; i < num_outs; i++) { + // we will substitute output i with the corresponding param indicated by inplace indices + auto output_var = legal_primfunc->params[num_params - num_outs + i]; + auto inplace_var = legal_primfunc->params[inplace_indices[i].IntValue()]; + var_subst_map.Set(output_var, inplace_var); + + // also do the same with the buffer vars + auto output_buffer = legal_primfunc->buffer_map.at(output_var); + auto inplace_buffer = legal_primfunc->buffer_map.at(inplace_var); + var_subst_map.Set(output_buffer->data, inplace_buffer->data); + buffer_subst_map.Set(output_buffer, inplace_buffer); + } + + // apply substitutions + legal_primfunc_cow->body = RemapBuffers(legal_primfunc->body, buffer_subst_map); + legal_primfunc_cow->body = tir::Substitute( + legal_primfunc->body, [&var_subst_map](const tir::Var& v) -> Optional { + if (var_subst_map.count(v)) { + return var_subst_map.at(v); + } + return Optional(); + }); + + // remove the now-unused outputs from the buffer map + auto buffer_map = legal_primfunc->buffer_map; + for (size_t i = 0; i < num_outs; i++) { + buffer_map.erase(legal_primfunc->params[num_params - num_outs + i]); + } + legal_primfunc_cow->buffer_map = buffer_map; + + // now get rid of the last num_outputs arguments + // (couldn't do earlier or else it would have thrown off the indexing) + legal_primfunc_cow->params = Array( + legal_primfunc->params.begin(), legal_primfunc->params.begin() + (num_params - num_outs)); + + // note: this might be a good time to get rid of the old legalized function, but we don't do it + // now because later ops might need the same one. Instead, we will clean up at the end + auto new_gv = builder_->AddFunction(legal_primfunc, inline_legal_op_name); + + // update the call (change the op, update the argument, change the attrs) + legalized_call_cow->op = call_tir_inplace_op; + + Array new_args(legalized_call->args.begin(), legalized_call->args.end()); + new_args.Set(0, new_gv); + legalized_call_cow->args = new_args; + + ObjectPtr attrs = make_object(); + attrs->inplace_indices = inplace_indices; + legalized_call_cow->attrs = Attrs(attrs); + + return legalized_call; + } + + // Made public for testing. + IRModule CurrentMod() { return builder_->GetContextIRModule(); } + + private: + const IRModule& mod_; + // Keep track of legalizers we add so we can clean up at the end. + Array legalizers_added; + // The current function's params will be treated as non-aliased + // (we are assuming good behavior on the user's part). 
+ Array func_params; + // map of eligible bindings to indices of arguments that can be used as the in-place target + Map> inplace_idxs; +}; + +namespace transform { + +Map> DataflowLivenessAnalysis(const DataflowBlock& block) { + auto liveness_ranges = AnalyzeLiveness(block); + Map> ret; + for (auto kv : liveness_ranges) { + ret.Set(kv.first, {kv.second.first, kv.second.second}); + } + return ret; +} + +Array DataflowAliasAnalysis(const DataflowBlock& block, Array inputs) { + AliasAnalyzer analyzer; + auto res = analyzer.Analyze(block, inputs); + auto alias_sets = res.first; + auto tuple_map = res.second; + Map> new_alias_sets; + Map>> new_tuple_map; + for (auto kv : alias_sets) { + Array aliases; + for (auto alias : kv.second) { + aliases.push_back(alias); + } + new_alias_sets.Set(kv.first, aliases); + } + for (auto kv : tuple_map) { + Array> elem_aliases; + for (auto alias_set : kv.second) { + Array dim_aliases; + for (auto alias : alias_set) { + dim_aliases.push_back(alias); + } + elem_aliases.push_back(dim_aliases); + } + new_tuple_map.Set(kv.first, elem_aliases); + } + return {new_alias_sets, new_tuple_map}; +} + +// this would be preferable to do as a dataflow block pass, +// but the transformation adds new PrimFuncs, so it affects the module +tvm::transform::Pass DataflowUseInplaceCalls() { + return tvm::transform::CreateModulePass( + [](const IRModule& mod, const PassContext& ctx) -> IRModule { + ModuleInplaceTransformer transformer(mod); + return transformer.Transform(); + }, + 0, "DataflowInsertInPlaceCalls", {}, false); +} + +Array> DataflowInplaceAnalysis(const DataflowBlock& block, + const Array& inputs, + const IRModule& mod) { + auto index_lists = relax::FindInplaceOpportunities(block, inputs, BlockBuilder::Create(mod)); + return {Array(index_lists.first.begin(), index_lists.first.end()), + Array(index_lists.second.begin(), index_lists.second.end())}; +} + +// these are exposed only for testing +TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowLivenessAnalysis") + .set_body_typed(DataflowLivenessAnalysis); +TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowAliasAnalysis") + .set_body_typed(DataflowAliasAnalysis); +TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowInplaceAnalysis") + .set_body_typed(DataflowInplaceAnalysis); +TVM_REGISTER_GLOBAL("relax.testing.transform.SingleInplaceCall") + .set_body_typed([](const IRModule& mod, const Call& call, + const Array& inplace_indices) -> Array { + ModuleInplaceTransformer transformer(mod); + auto ret_call = transformer.CreateInplaceCall(call, inplace_indices); + return Array{ret_call, transformer.CurrentMod()}; + }); + +// actually exposed +TVM_REGISTER_GLOBAL("relax.transform.DataflowUseInplaceCalls") + .set_body_typed(DataflowUseInplaceCalls); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 785dc6d96320..ef9438350ce0 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -97,11 +97,13 @@ ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifi Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { static const Op& call_tir_op = Op::Get("relax.call_tir"); + static const Op& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); static const Op& call_tir_with_grad_op = Op::Get("relax.call_tir_with_grad"); static const Op& call_tir_local_view 
= Op::Get("relax.dist.call_tir_local_view"); if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op) && - !n->op.same_as(call_tir_with_grad_op) && !n->op.same_as(call_tir_local_view)) { + !n->op.same_as(call_tir_with_grad_op) && !n->op.same_as(call_tir_local_view) && + !n->op.same_as(call_tir_inplace_op)) { return NullOpt; } ICHECK(n->args.size() == 2 || n->args.size() == 3); @@ -135,6 +137,19 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& kwargs_values.push_back(d->AsDoc(o_sinfo, o_sinfo_p)); } + // for call_tir_inplace, we also need to include the inplace args + if (n->op.same_as(call_tir_inplace_op)) { + kwargs_keys.push_back("inplace_indices"); + Array index_fields; + if (auto* call_tir_inplace_attrs = n->attrs.as()) { + for (auto inplace_index : call_tir_inplace_attrs->inplace_indices) { + index_fields.push_back( + LiteralDoc::Int(inplace_index.IntValue(), n_p->Attr("attrs")->Attr("inplace_indices"))); + } + } + kwargs_values.push_back(ListDoc(index_fields)); + } + // start of specially handling call_tir_with_grad if (const auto* call_tir_with_grad_attrs = n->attrs.as()) { kwargs_keys.push_back("te_grad_name"); @@ -163,6 +178,8 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& return Relax(d, "dist.call_tir_local_view")->Call(args, kwargs_keys, kwargs_values); } else if (is_dtensor) { return Relax(d, "dist.call_tir")->Call(args, kwargs_keys, kwargs_values); + } else if (n->op.same_as(call_tir_inplace_op)) { + return Relax(d, "call_tir_inplace")->Call(args, kwargs_keys, kwargs_values); } else { return Relax(d, "call_tir")->Call(args, kwargs_keys, kwargs_values); } diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py new file mode 100644 index 000000000000..8d5eb07c7858 --- /dev/null +++ b/tests/python/relax/test_dataflow_inplace.py @@ -0,0 +1,644 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import List, Set, Tuple +import tvm +from tvm import relax, testing +from tvm.relax.transform import DataflowUseInplaceCalls +from tvm.relax.testing.transform import ( + dataflow_liveness_analysis, + dataflow_alias_analysis, + dataflow_inplace_analysis, + dataflow_single_inplace_call, +) +from tvm.script.parser import ir as I, relax as R, tir as T + +import numpy as np + + +def test_liveness_analysis(): + @I.ir_module + class BasicLiveness: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + with R.dataflow(): + y = R.const(1, dtype="int32") + z = R.add(x, y) + q = R.multiply(z, y) + p = R.add(z, q) + n = R.multiply(p, p) + R.output(n, p) + return n + + block = BasicLiveness["main"].body.blocks[0] + live_ranges = dataflow_liveness_analysis(block) + expected_ranges = { + # x is live past the binding block + "x": (-1, 5), + "y": (0, 2), + "z": (1, 3), + "q": (2, 3), + # exposed though ultimately not used + "p": (3, 5), + "n": (4, 5), + } + actual_ranges = {var.name_hint: live_range for var, live_range in live_ranges.items()} + assert actual_ranges == expected_ranges + + +def test_alias_analysis_basic(): + @I.ir_module + class BasicAliasAnalysis: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + with R.dataflow(): + y = x # y is an alias of x + z = R.add(y, y) # fresh value + n = z # alias of z + R.output(n) + return n + + block = BasicAliasAnalysis["main"].body.blocks[0] + alias_sets, tuple_map = dataflow_alias_analysis(block, BasicAliasAnalysis["main"].params) + expected = { + "x": {0}, + "y": {0}, + "z": {1}, + "n": {1}, + } + + for var, alias_set in alias_sets.items(): + assert alias_set == expected[var.name_hint] + assert tuple_map == {} + + +def test_alias_analysis_tuple(): + @I.ir_module + class AliasesWithTuples: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + with R.dataflow(): + y = R.const(1, dtype="int32") + t = (x, y) + a = t[0] + b = t[1] + c = t[0] + d = t[1] + u = t + e = t[0] + f = t[1] + z = R.add(c, d) + n = z + R.output(n) + return n + + block = AliasesWithTuples["main"].body.blocks[0] + alias_sets, tuple_map = dataflow_alias_analysis(block, AliasesWithTuples["main"].params) + expected = { + "x": {0}, + "y": {1}, + "t": {2}, + "a": {0}, + "b": {1}, + "c": {0}, + "d": {1}, + "u": {2}, + "e": {0}, + "f": {1}, + "z": {3}, + "n": {3}, + } + + actual_alias_sets = {var.name_hint: alias_set for var, alias_set in alias_sets.items()} + assert expected == actual_alias_sets + assert 2 in tuple_map + assert tuple_map[2] == [{0}, {1}] + + +def test_alias_split(): + @I.ir_module + class AliasSplit: + @R.function + def main(x: R.Tensor((60,), "int32")) -> R.Tensor((15,), "int32"): + with R.dataflow(): + t = R.split(x, 4) + y = t[0] + z = t[1] + q = t[2] + p = t[3] + n = z + R.output(n) + return n + + block = AliasSplit["main"].body.blocks[0] + alias_sets, tuple_map = dataflow_alias_analysis(block, AliasSplit["main"].params) + expected = { + "x": {0}, + "t": {1}, + "y": {2}, + "z": {3}, + "q": {4}, + "p": {5}, + "n": {3}, + } + + actual_alias_sets = {var.name_hint: alias_set for var, alias_set in alias_sets.items()} + assert expected == actual_alias_sets + assert len(tuple_map) == 1 + assert 1 in tuple_map + assert tuple_map[1] == [{2}, {3}, {4}, {5}] + + +def test_alias_call_tir(): + # call TIR can yield either a single tensor or a tuple + @I.ir_module + class AliasCallTir: + @T.prim_func + def tir_id(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "tir_id"}) + m = 
T.int32() + n = T.int32() + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (m, n)) + + for i, j in T.grid(m, n): + with T.block("id"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + + @T.prim_func + def tir_id2(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_id"}) + m = T.int32() + n = T.int32() + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (m, n)) + C = T.match_buffer(z, (m, n)) + + for i, j in T.grid(m, n): + with T.block("id"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + C[vi, vj] = A[vi, vj] + + @R.function + def main(x: R.Tensor((10, 10), "int32")) -> R.Tensor((10, 10), "int32"): + with R.dataflow(): + cls = AliasCallTir + y = R.call_tir(cls.tir_id, (x,), out_sinfo=R.Tensor((10, 10), "int32")) + t = R.call_tir( + cls.tir_id2, + (y,), + out_sinfo=[R.Tensor((10, 10), "int32"), R.Tensor((10, 10), "int32")], + ) + z = y + p = t[0] + q = t[1] + u = t + m = u[0] + n = u[1] + v = n + R.output(v) + return v + + block = AliasCallTir["main"].body.blocks[0] + alias_sets, tuple_map = dataflow_alias_analysis(block, AliasCallTir["main"].params) + expected = { + "x": {0}, + "y": {1}, + "t": {2}, + "z": {1}, + "p": {3}, + "q": {4}, + "u": {2}, + "m": {3}, + "n": {4}, + "v": {4}, + } + + actual_alias_sets = {var.name_hint: alias_set for var, alias_set in alias_sets.items()} + assert expected == actual_alias_sets + assert len(tuple_map) == 1 + assert 2 in tuple_map + assert tuple_map[2] == [{3}, {4}] + + +def test_mystery_calls(): + @I.ir_module + class AliasChaosCalls: + @R.function + def identity(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + return x + + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + with R.dataflow(): + cls = AliasChaosCalls + y = cls.identity(x) + z = cls.identity(y) + m = R.const(1, dtype="int32") + n = R.const(2, dtype="int32") + t = (m, n) + a = R.call_pure_packed( + "chaos", t, sinfo_args=R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")) + ) + b = a[0] + c = a[1] + R.output(c) + return c + + block = AliasChaosCalls["main"].body.blocks[0] + alias_sets, tuple_map = dataflow_alias_analysis(block, AliasChaosCalls["main"].params) + expected = { + "x": {0}, + "y": {0, 1}, + "z": {0, 1, 2}, + "m": {3}, + "n": {4}, + "t": {5}, + "a": {3, 4, 5, 6, 7, 8}, # either t or a fresh tuple + "b": {3, 4, 5, 6, 7, 8}, # the tuple components can be aliased to any member... + "c": {3, 4, 5, 6, 7, 8}, # the tuple components can be aliased to any member... 
+ # (in principle, we can use type information to narrow down the aliasing) + } + + actual_alias_sets = {var.name_hint: alias_set for var, alias_set in alias_sets.items()} + assert expected == actual_alias_sets + assert len(tuple_map) == 2 + assert 5 in tuple_map + assert tuple_map[5] == [{3}, {4}] + assert 6 in tuple_map + assert tuple_map[6] == [{3, 4, 5, 6, 7, 8}, {3, 4, 5, 6, 7, 8}] + + +def test_alias_external_value(): + @I.ir_module + class AliasExternalValue: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.const(1, dtype="int32") # not in DF block, treated as external + t1 = (y, y) # not in DF block, treated as external + with R.dataflow(): + z = y # mystery value + a = R.const(2, dtype="int32") + t2 = (z, a) + b = t2[0] + c = t1[1] # tuple index into external value + R.output(b) + return b + + block = AliasExternalValue["main"].body.blocks[1] + alias_sets, tuple_map = dataflow_alias_analysis(block, AliasExternalValue["main"].params) + expected = { + "x": {0}, + "z": {-1}, + "a": {1}, + "t2": {2}, + "b": {-1}, + "c": {-1}, + } + + actual_alias_sets = {var.name_hint: alias_set for var, alias_set in alias_sets.items()} + assert expected == actual_alias_sets + assert len(tuple_map) == 1 + assert 2 in tuple_map + assert tuple_map[2] == [{-1}, {1}] + + +def test_inplace_simple_case(): + @I.ir_module + class InplaceBasic: + @R.function + def main( + x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") + ) -> R.Tensor((2, 3), "int32"): + with R.dataflow(): + z = R.add(x, y) # cannot be done inplace: x and y are live later + p = R.add(z, z) # can be done inplace: z is not used later + r = p # alias of p + m = R.multiply(p, p) # p is not used later but r is, so can't do inplace + n = R.add(m, r) # can be done inplace: r is not used again + ret = R.subtract(n, m) # can be done inplace: neither is used again + R.output(ret) + return ret + + block = InplaceBasic["main"].body.blocks[0] + size_match, exact_match = dataflow_inplace_analysis( + block, InplaceBasic["main"].params, InplaceBasic + ) + + # order does not matter for the listing of candidates, so we have to implement as sets + def assert_candidate_list( + actual: List[Tuple[int, Set[int]]], expected: List[Tuple[int, Set[int]]] + ) -> None: + assert len(actual) == len(expected) + for i in range(len(actual)): + assert actual[i][0] == expected[i][0] + assert len(expected[i][1]) == len(actual[i][1]) + for idx in actual[i][1]: + assert idx in expected[i][1] + + assert_candidate_list(size_match, [(1, {0, 1}), (4, {1}), (5, {0, 1})]) + # TODO(@slyubomirsky): I couldn't think of an easy example where sizes don't match, + # but broadcasting might cause it to happen + assert_candidate_list(exact_match, [(1, {0, 1}), (4, {1}), (5, {0, 1})]) + + +def test_inplace_single_call(): + @I.ir_module + class TestModule: + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + z = R.add(x, y) + q = R.nn.silu(z) + return q + + add_call = TestModule["main"].body.blocks[0].bindings[0].value + new_add, new_mod = dataflow_single_inplace_call(TestModule, add_call, [0]) + + @T.prim_func(private=True) + def expected_add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) 
+ T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1] + + tvm.ir.assert_structural_equal(new_mod["add_inplace"], expected_add) + assert new_add.op.name == "relax.call_tir_inplace" + assert new_add.args[0].name_hint == "add_inplace" + for i, arg in enumerate(new_add.args[1].fields): + arg == add_call.args[i] + new_add.attrs.inplace_indices == [0] + + @T.prim_func(private=True) + def expected_silu(A: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + compute = T.alloc_buffer((T.int64(2), T.int64(3))) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.sigmoid(A[v_i0, v_i1]) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], compute[v_ax0, v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] * compute[v_ax0, v_ax1] + + silu_call = TestModule["main"].body.blocks[0].bindings[1].value + new_silu, new_mod = dataflow_single_inplace_call(TestModule, silu_call, [0]) + + tvm.ir.assert_structural_equal(new_mod["silu_inplace"], expected_silu) + assert new_silu.op.name == "relax.call_tir_inplace" + assert new_silu.args[0].name_hint == "silu_inplace" + for i, arg in enumerate(new_silu.args[1].fields): + arg == silu_call.args[i] + new_silu.attrs.inplace_indices == [0] + + +def test_insert_inplace_calls(): + @I.ir_module + class EndToEndTest: + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + with R.dataflow(): + z = R.add(x, y) # broadcast happens here + # Cannot be done in-place because x is an argument. + a = R.add(z, y) # this one can be done in-place + q = R.multiply(a, y) # broadcast again, a is eligible + r = R.subtract(y, y) # cannot be done in-place because y is an argument + s = R.subtract(r, r) # No broadcast. 
Can be done in-place + m = R.multiply(q, s) # should give us all zeros + R.output(m) + return m + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def add_inplace( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(1), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[T.int64(0), v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[T.int64(0), v_ax1] + + @T.prim_func(private=True) + def multiply_inplace( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(1), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[T.int64(0), v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[T.int64(0), v_ax1] + + @T.prim_func(private=True) + def subtract_inplace( + A: T.Buffer((T.int64(1), T.int64(3)), "float32"), + B: T.Buffer((T.int64(1), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(1), T.int64(3)): + with T.block("T_subtract"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) + T.writes(B[v_ax0, v_ax1]) + B[v_ax0, v_ax1] = A[v_ax0, v_ax1] - B[v_ax0, v_ax1] + + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Expected + with R.dataflow(): + z: R.Tensor((2, 3), dtype="float32") = R.add(x, y) + a: R.Tensor((2, 3), dtype="float32") = R.call_tir_inplace( + cls.add_inplace, + (z, y), + inplace_indices=[0], + out_sinfo=[ + R.Tensor((2, 3), dtype="float32"), + ], + ) + q: R.Tensor((2, 3), dtype="float32") = R.call_tir_inplace( + cls.multiply_inplace, + (a, y), + inplace_indices=[0], + out_sinfo=[ + R.Tensor((2, 3), dtype="float32"), + ], + ) + r: R.Tensor((1, 3), dtype="float32") = R.subtract(y, y) + s: R.Tensor((1, 3), dtype="float32") = R.call_tir_inplace( + cls.subtract_inplace, + (r, r), + inplace_indices=[1], + out_sinfo=[ + R.Tensor((1, 3), dtype="float32"), + ], + ) + m: R.Tensor((2, 3), dtype="float32") = R.call_tir_inplace( + cls.multiply_inplace, + (q, s), + inplace_indices=[0], + out_sinfo=[ + R.Tensor((2, 3), dtype="float32"), + ], + ) + R.output(m) + return m + + transform_pass = DataflowUseInplaceCalls() + new_mod = transform_pass(EndToEndTest) + tvm.ir.assert_structural_equal(new_mod, Expected) + + x = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + y = tvm.nd.array(np.random.rand(1, 3).astype("float32")) + expected = np.zeros((2, 3), dtype="float32") + + target = tvm.target.Target("llvm") + ex = relax.build(new_mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"](x, y) + assert (expected == res.numpy()).all() + + +def test_dynamic(): + @I.ir_module + class DynamicTestCase: + @R.function + def main( + x: R.Tensor(("a", "b"), dtype="float32"), y: R.Tensor(("a", "b"), dtype="float32") + ) -> R.Tensor(("a", "b"), dtype="float32"): + with R.dataflow(): + z = R.add(x, y) + # Cannot be done in-place because x and y are arguments + a = R.add(z, y) # this one can be done in-place + s = R.subtract(a, a) # No broadcast. 
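Can be done in-place
+                # all operands share the symbolic shape (a, b), so the pass can prove the
+                # sizes match; contrast with test_dynamic_mismatch below, where it cannot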
+                R.output(s)
+            return s
+
+    # the result should be all zeroes
+    transform_pass = DataflowUseInplaceCalls()
+    new_mod = transform_pass(DynamicTestCase)
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def add_inplace(var_A: T.handle, var_B: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            a, b = T.int64(), T.int64()
+            A = T.match_buffer(var_A, (a, b))
+            B = T.match_buffer(var_B, (a, b))
+            for ax0, ax1 in T.grid(a, b):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]
+
+        @T.prim_func(private=True)
+        def subtract_inplace(var_A: T.handle, var_B: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            a, b = T.int64(), T.int64()
+            A = T.match_buffer(var_A, (a, b))
+            B = T.match_buffer(var_B, (a, b))
+            for ax0, ax1 in T.grid(a, b):
+                with T.block("T_subtract"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
+                    T.writes(B[v_ax0, v_ax1])
+                    B[v_ax0, v_ax1] = A[v_ax0, v_ax1] - B[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor(("a", "b"), dtype="float32"), y: R.Tensor(("a", "b"), dtype="float32")
+        ) -> R.Tensor(("a", "b"), dtype="float32"):
+            a = T.int64()
+            b = T.int64()
+            cls = Expected
+            with R.dataflow():
+                z = R.add(x, y)
+                a_1 = R.call_tir_inplace(
+                    cls.add_inplace,
+                    (z, y),
+                    out_sinfo=R.Tensor((a, b), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                s = R.call_tir_inplace(
+                    cls.subtract_inplace,
+                    (a_1, a_1),
+                    out_sinfo=R.Tensor((a, b), dtype="float32"),
+                    inplace_indices=[1],
+                )
+                R.output(s)
+            return s
+
+    tvm.ir.assert_structural_equal(new_mod, Expected, map_free_vars=True)
+    x = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+    y = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+    expected = np.zeros((2, 3), dtype="float32")
+
+    target = tvm.target.Target("llvm")
+    ex = relax.build(new_mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    res = vm["main"](x, y)
+    assert (expected == res.numpy()).all()
+
+
+def test_dynamic_mismatch():
+    # cannot statically prove the shapes to be equal, so the module should be unchanged
+    @I.ir_module
+    class DynamicMismatchTestCase:
+        @R.function
+        def main(
+            x: R.Tensor(("a", "b"), dtype="float32"), y: R.Tensor(("c", "d"), dtype="float32")
+        ):
+            with R.dataflow():
+                z = R.add(x, y)
+                # Cannot be done in-place because x and y are arguments
+                a = R.add(z, y)  # cannot conclude that shapes match
+                R.output(a)
+            return a
+
+    transform_pass = DataflowUseInplaceCalls()
+    new_mod = transform_pass(DynamicMismatchTestCase)
+    tvm.ir.assert_structural_equal(new_mod, DynamicMismatchTestCase)
+
+
+if __name__ == "__main__":
+    testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index b45c3c6e4a93..f317d04f59ae 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -986,6 +986,42 @@ def main(v0: R.Tensor([54, 96], "float32")):
     _check(Module)


+def test_call_tir_inplace():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def copy(
+            A: T.Buffer((2, 3), "int32"),
+            B: T.Buffer((2, 3), "int32"),
+            out1: T.Buffer((2, 3), "int32"),
+        ):
+            # copies the contents of B into A and out1
+            T.func_attr({"tir.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_zeros"):
+                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(B[ax0, ax1])
+                    T.writes(A[ax0, ax1],
out1[ax0, ax1]) + A[ax0, ax1] = B[ax0, ax1] + out1[ax0, ax1] = B[ax0, ax1] + + @R.function + def main( + x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") + ) -> R.Tuple( + R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32") + ): + res = R.call_tir_inplace( + Module.copy, + (x, y), + [0, -1], + [R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], + ) + return res + + _check(Module) + + def test_local_function(): @R.function def main( diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index dc3334f216c0..530e45e61074 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -399,6 +399,31 @@ def test_call_tir_with_grad(): ) +def test_call_tir_inplace(): + x = relax.Var("x", R.Tensor((32, 32), dtype="int32")) + y = relax.Var("y", R.Tensor((32, 32), dtype="int32")) + t = tir.Var("t", dtype="int64") + call = relax.call_tir_inplace( + relax.GlobalVar("tir_func"), + ( + x, + y, + ), + inplace_indices=[-1, 0], + out_sinfo=[R.Tensor((32, 32), dtype="int32"), R.Tensor((32, 32), dtype="int32")], + tir_vars=[t], + ) + _assert_print( + call, + """ +x: R.Tensor((32, 32), dtype="int32") +y: R.Tensor((32, 32), dtype="int32") +t = T.int64() +R.call_tir_inplace(tir_func, (x, y), out_sinfo=[R.Tensor((32, 32), dtype="int32"), R.Tensor((32, 32), dtype="int32")], inplace_indices=[-1, 0], tir_vars=R.shape([t])) + """, + ) + + def test_seq_expr(): x = tir.Var("x", "int64") a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))