From 4ae40863a57f197a6941f809231d21ce860c4584 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 2 Oct 2023 15:42:57 -0400 Subject: [PATCH 01/55] Implement basic analyses --- src/relax/transform/dataflow_in_place.cc | 275 +++++++++++++++++++++++ 1 file changed, 275 insertions(+) create mode 100644 src/relax/transform/dataflow_in_place.cc diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc new file mode 100644 index 000000000000..cf44b88b5980 --- /dev/null +++ b/src/relax/transform/dataflow_in_place.cc @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include "utils.h" + +namespace tvm { +namespace relax { + +std::unordered_map, ObjectPtrHash, ObjectPtrEqual> analyze_liveness( + const DataflowBlock& block) { + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> ret; + for (int i = block->bindings.size() - 1; i >= 0; i--) { + Binding b = block->bindings[i]; + Var defined_var = b->var; + Expr value; + if (const auto* var_binding = b.as()) { + value = var_binding->value; + } else if (const auto* match_binding = b.as()) { + value = match_binding->value; + } else { + CHECK(false) << "Invalid binding"; // impossible + } + Array used_vars; + // for a function literal, we consider only the free vars + // (those captured from the outer scope) + if (value.as()) { + used_vars = FreeVars(value); + } else { + used_vars = AllVars(value); + } + + for (auto var : used_vars) { + *******************************************************************************************************************************************************************************************************************************if ( + !ret.count(var)) { + ret[var] = {-1, i}; + } + } + + if (!ret.count(defined_var)) { + ret[defined_var] = {i, block->bindings.size()}; + } else { + // this means the var is used later but we encountered its definition now + auto last_range = ret[defined_var]; + CHECK_EQ(last_range.first, -1); + std::pair new_range = {i, last_range.second}; + ret[defined_var] = new_range; + } + } + return ret; +} + +class AliasAnalyzer { + public: + explicit AliasAnalyzer() : alias_map_(), tuple_map_(), captured_by_functions_(), mem_idx_(0) {} + + // alias: map of var to memory locations (we will call these indices and use -1 as an index for + // "unknown") + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> Analyze( + const DataflowBlock& block, const Array& inputs) { + for (auto input : inputs) { + int curr_idx = get_fresh_idx(); + alias_map_[input] = {curr_idx}; + if (auto* tup_info = GetStructInfoAs(input)) { + insert_fresh_tuple(curr_idx, tup_info->fields.size()); + } + } + + for (const Binding& binding : block->bindings) { + Var current_var = binding->var; + Expr value; + if (const auto* var_binding = binding.as()) { + value = var_binding->value; + } else if (const auto* match_binding = binding.as()) { + value = match_binding->value; + } else { + CHECK(false) << "Invalid binding"; // impossible + } + alias_map_[current_var] = get_alias_set(value); + } + + return alias_map_; + } + + private: + int get_fresh_idx() { + int ret = mem_idx_; + mem_idx_++; + return ret; + } + + void insert_fresh_tuple(int tup_idx, size_t num_members) { + std::vector> tuple_set; + for (int i = 0; i < num_members; i++) { + tuple_set.push_back({get_fresh_idx()}); + } + tuple_map_[tup_idx] = tuple_set; + } + + // Conservative extremely pessimistic assumption: + // assume that the result of a non-op call can be aliased to anything + // ever passed to or returned from any non-op call. + // For tuples, assume all members are aliased. Yeah, it's bad. + // (Skip first arg is for handling call_pure_packed, where the first arg is an ExternFunc that we + // should ignore) + std::unordered_set handle_mystery_call(const CallNode* call_node, + bool skip_first_arg = false) { + // the result may or may not be newly allocated + std::unordered_set ret; + int res_idx = get_fresh_idx(); + captured_by_functions_.insert(res_idx); + + for (size_t i = (skip_first_arg) ? 1 : 0; i < call_node->args.size(); i++) { + auto arg = call_node->args[i]; + auto arg_alias_set = get_alias_set(arg); + // for any tuples in the set, also throw in all components since they can get captured + // too + std::vector captured_tuples; + for (int alias_idx : arg_alias_set) { + if (tuple_map_.count(alias_idx)) { + captured_tuples.push_back(alias_idx); + } + } + // this is to avoid modifying the set while we're iterating over it + for (int tuple_idx : captured_tuples) { + auto tuple_members = tuple_map_[tuple_idx]; + for (std::unordered_set tuple_member : tuple_members) { + arg_alias_set.insert(tuple_member.begin(), tuple_member.end()); + } + } + captured_by_functions_.insert(arg_alias_set.begin(), arg_alias_set.end()); + } + ret.insert(captured_by_functions_.begin(), captured_by_functions_.end()); + return ret; + } + + std::unordered_set get_alias_set(const Expr& expr) { + std::unordered_set ret; + + // cases for value: + // constant: it's a fresh index + // var: look up in alias map (-1 if not present) + // op call: assume it's fresh (may need to make list of exceptions) + // tuple: fresh entry in tuple index, recurse to determine indices for values + // function/packed call: chaos reigns, alias with everything ever passed or returned from func + // (if tuple is passed, assume also aliased with all members of the tuple) + // tuple index: -1 if tuple is not in tuple map, otherwise look up corresponding entry + // function constant: give them a fresh index (TODO: we can handle in more detail if this is a + // case we need to support) prim value: fresh index if node: should not happen inside dataflow + // block + if (value.as() || value.as() || value.as()) { + // TODO(@slyubomirsky): We will probably want special handling for closures + ret.insert(get_fresh_idx()); + } else if (auto* target_var_node = value.as()) { + auto target_var = Downcast(target_var_node); + if (alias_map_.count(target_var)) { + ret.insert(alias_map_[target_var].begin(), alias_map_[target_var].end()); + } else { + ret.insert(-1); + } + } else if (auto* target_tuple = value.as()) { + // fresh idx but we update the tuple map + int tup_idx = get_fresh_idx(); + ret.insert(tup_idx); + std::vector> new_tuple_map; + for (auto field = target_tuple->fields) { + new_tuple_map.push_back(get_alias_set(field)); + } + tuple_map_[tup_idx] = new_tuple_map; + } else if (auto* target_tgi = value.as()) { + std::unordered_set tuple_set = get_alias_set(target_tgi->tuple); + // if there's only one possibility for the tuple and it's in the tuple map, + // index into it + if (tuple_set.size() == 1) { + int index = *(tuple_set.begin()); + if (tuple_map_.count(index)) { + return tuple_map_[index][target_tgi->index]; + } else { + ret.insert(-1); + } + } else { + ret.insert(-1); + } + } else if (auto* call_node = value.as()) { + if (auto* op_node = call_node->op.as()) { + // call_pure_packed: treat as non-op call + if (op_node.name == "call_pure_packed") { + return handle_mystery_call(call_node, true); + } + // split: Returns a tuple, treat as allocation + else if (op_node.name == "split") { + // tuple is freshly allocated, but also add components to the tuple map + int tup_idx = get_fresh_idx(); + ret.insert(tup_idx); + + std::vector> tuple_set; + auto attrs = Downcast(call_node->attrs); + int num_members = 0; + if (const auto* indices = attrs->indices_or_sections.as()) { + // see struct info rule for split + num_members = indices.size() + 1; + } else if (const auto* n_section = attrs->indices_or_sections.as()) { + num_members = n_section->value; + } else { + CHECK(false) << "Invalid split call"; + } + + for (int i = 0; i < num_members; i++) { + tuple_set.push_back({get_fresh_idx()}); + } + tuple_map_[tup_idx] = tuple_set; + } + // call_tir: can potentially return a tuple + else if (op_node.name == "call_tir") { + if (auto* tuple_struct_info = call->sinfo_args[0].as()) { + int tup_idx = get_fresh_idx(); + ret.insert(tup_idx); + + int num_members = tuple_struct_info->fields.size(); + insert_fresh_tuple(tup_idx, num_members); + } else { + ret.insert(get_fresh_idx()); + } + } + // We are assuming most op calls return a single fresh allocation. + // We may have to track more exceptions + else { + ret.insert(get_fresh_idx()); + } + } else { + // assume any non-op call can be extremely dangerous and do anything + return handle_mystery_call(call_node); + } + } + + return ret; + } + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> alias_map_; + std::unordered_map>> tuple_map_; + std::unordered_set captured_by_functions_; + int mem_idx_; +}; + +// export for testing + +// check for in-place eligibility: +// 1. see if there's an arg big enough to hold the result +// 2. see if the arg is live past the call +// 3. see if the arg has an alias that's live past the call +// if conditions are met, we're good to go + +} // namespace relax +} // namespace tvm \ No newline at end of file From 2ffec3a4ebf61bceaa119f967c013b6e1938cb13 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 2 Oct 2023 18:58:42 -0400 Subject: [PATCH 02/55] Fix typo --- src/relax/transform/dataflow_in_place.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index cf44b88b5980..12fff602047b 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -52,8 +52,7 @@ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> anal } for (auto var : used_vars) { - *******************************************************************************************************************************************************************************************************************************if ( - !ret.count(var)) { + if (!ret.count(var)) { ret[var] = {-1, i}; } } From 3faf91413dbf8b1d74bd11b21c0e14d00932ef3f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 3 Oct 2023 16:12:08 -0400 Subject: [PATCH 03/55] Add tests for analyses --- include/tvm/relax/analysis.h | 4 + python/tvm/relax/analysis/analysis.py | 19 +- src/relax/transform/dataflow_in_place.cc | 129 +++++---- tests/python/relax/test_dataflow_in_place.py | 262 +++++++++++++++++++ 4 files changed, 358 insertions(+), 56 deletions(-) create mode 100644 tests/python/relax/test_dataflow_in_place.py diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 6e2209d51950..7ca021fc1558 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -533,6 +533,10 @@ TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true); TVM_DLL Map> SuggestLayoutTransforms( const Function& fn, Array write_buffer_transformations); +// included for testing purposes +TVM_DLL Map> DataflowLivenessAnalysis(const DataflowBlock& block); +TVM_DLL Map> DataflowAliasAnalysis(const DataflowBlock& block, + Array inputs); } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 38f5ea2fea0e..ddee364c9490 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -21,7 +21,7 @@ configuring the passes and scripting them in Python. """ -from typing import Dict, List, Optional, Union, Callable +from typing import Dict, List, Optional, Set, Tuple, Union, Callable from enum import IntEnum import tvm @@ -528,3 +528,20 @@ def detect_recursion(mod: tvm.IRModule) -> List[List[GlobalVar]]: with any other, it will be a singleton in this list. """ return _ffi_api.detect_recursion(mod) # type: ignore + + +# expose for testing +def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int]]: + live_ranges = _ffi_api.DataflowLivenessAnalysis(block) # type: ignore + ret = {} + for (var, live_range) in live_ranges.items(): + ret[var] = tuple(live_range) + return ret # type: ignore + + +def dataflow_alias_analysis(block: DataflowBlock, inputs: List[Var]) -> Dict[Var, Set[int]]: + alias_sets = _ffi_api.DataflowAliasAnalysis(block, inputs) # type: ignore + ret = {} + for (var, alias_set) in alias_sets.items(): + ret[var] = set(alias_set) + return ret # type: ignore diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 12fff602047b..02d3e36b694d 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -82,7 +82,7 @@ class AliasAnalyzer { int curr_idx = get_fresh_idx(); alias_map_[input] = {curr_idx}; if (auto* tup_info = GetStructInfoAs(input)) { - insert_fresh_tuple(curr_idx, tup_info->fields.size()); + insert_fresh_tuple(curr_idx, tup_info); } } @@ -96,7 +96,7 @@ class AliasAnalyzer { } else { CHECK(false) << "Invalid binding"; // impossible } - alias_map_[current_var] = get_alias_set(value); + alias_map_[current_var] = get_alias_set(value, current_var); } return alias_map_; @@ -109,52 +109,60 @@ class AliasAnalyzer { return ret; } - void insert_fresh_tuple(int tup_idx, size_t num_members) { + void insert_fresh_tuple(int tup_idx, const TupleStructInfoNode* tup_info) { std::vector> tuple_set; - for (int i = 0; i < num_members; i++) { - tuple_set.push_back({get_fresh_idx()}); + for (int i = 0; i < static_cast(tup_info->fields.size()); i++) { + int curr_field = get_fresh_idx(); + tuple_set.push_back({curr_field}); + if (auto* nested_tup_info = tup_info->fields[i].as()) { + insert_fresh_tuple(curr_field, nested_tup_info); + } } tuple_map_[tup_idx] = tuple_set; } + // capture the given index and also its tuple components (including recursively) + // if they exist + void add_to_captured_set(int idx) { + captured_by_functions_.insert(idx); + if (tuple_map_.count(idx)) { + for (auto comp_set : tuple_map_[idx]) { + for (auto tup_comp_idx : comp_set) { + add_to_captured_set(tup_comp_idx); + } + } + } + } + // Conservative extremely pessimistic assumption: // assume that the result of a non-op call can be aliased to anything // ever passed to or returned from any non-op call. // For tuples, assume all members are aliased. Yeah, it's bad. // (Skip first arg is for handling call_pure_packed, where the first arg is an ExternFunc that we // should ignore) - std::unordered_set handle_mystery_call(const CallNode* call_node, + std::unordered_set handle_mystery_call(const CallNode* call_node, const Var& bound_var, bool skip_first_arg = false) { // the result may or may not be newly allocated std::unordered_set ret; int res_idx = get_fresh_idx(); - captured_by_functions_.insert(res_idx); + // the result may be a tuple + if (auto* tup_info_node = GetStructInfoAs(bound_var)) { + insert_fresh_tuple(res_idx, tup_info_node); + } + add_to_captured_set(res_idx); for (size_t i = (skip_first_arg) ? 1 : 0; i < call_node->args.size(); i++) { auto arg = call_node->args[i]; - auto arg_alias_set = get_alias_set(arg); - // for any tuples in the set, also throw in all components since they can get captured - // too - std::vector captured_tuples; + auto arg_alias_set = get_alias_set(arg, bound_var); for (int alias_idx : arg_alias_set) { - if (tuple_map_.count(alias_idx)) { - captured_tuples.push_back(alias_idx); - } + add_to_captured_set(alias_idx); } - // this is to avoid modifying the set while we're iterating over it - for (int tuple_idx : captured_tuples) { - auto tuple_members = tuple_map_[tuple_idx]; - for (std::unordered_set tuple_member : tuple_members) { - arg_alias_set.insert(tuple_member.begin(), tuple_member.end()); - } - } - captured_by_functions_.insert(arg_alias_set.begin(), arg_alias_set.end()); } ret.insert(captured_by_functions_.begin(), captured_by_functions_.end()); return ret; } - std::unordered_set get_alias_set(const Expr& expr) { + std::unordered_set get_alias_set(const Expr& value, const Var& bound_var) { std::unordered_set ret; // cases for value: @@ -172,7 +180,7 @@ class AliasAnalyzer { // TODO(@slyubomirsky): We will probably want special handling for closures ret.insert(get_fresh_idx()); } else if (auto* target_var_node = value.as()) { - auto target_var = Downcast(target_var_node); + auto target_var = GetRef(target_var_node); if (alias_map_.count(target_var)) { ret.insert(alias_map_[target_var].begin(), alias_map_[target_var].end()); } else { @@ -183,12 +191,12 @@ class AliasAnalyzer { int tup_idx = get_fresh_idx(); ret.insert(tup_idx); std::vector> new_tuple_map; - for (auto field = target_tuple->fields) { - new_tuple_map.push_back(get_alias_set(field)); + for (auto field : target_tuple->fields) { + new_tuple_map.push_back(get_alias_set(field, bound_var)); } tuple_map_[tup_idx] = new_tuple_map; } else if (auto* target_tgi = value.as()) { - std::unordered_set tuple_set = get_alias_set(target_tgi->tuple); + std::unordered_set tuple_set = get_alias_set(target_tgi->tuple, bound_var); // if there's only one possibility for the tuple and it's in the tuple map, // index into it if (tuple_set.size() == 1) { @@ -204,40 +212,23 @@ class AliasAnalyzer { } else if (auto* call_node = value.as()) { if (auto* op_node = call_node->op.as()) { // call_pure_packed: treat as non-op call - if (op_node.name == "call_pure_packed") { - return handle_mystery_call(call_node, true); + if (op_node->name == "relax.call_pure_packed") { + return handle_mystery_call(call_node, bound_var, true); } // split: Returns a tuple, treat as allocation - else if (op_node.name == "split") { + else if (op_node->name == "relax.split") { // tuple is freshly allocated, but also add components to the tuple map int tup_idx = get_fresh_idx(); ret.insert(tup_idx); - - std::vector> tuple_set; - auto attrs = Downcast(call_node->attrs); - int num_members = 0; - if (const auto* indices = attrs->indices_or_sections.as()) { - // see struct info rule for split - num_members = indices.size() + 1; - } else if (const auto* n_section = attrs->indices_or_sections.as()) { - num_members = n_section->value; - } else { - CHECK(false) << "Invalid split call"; - } - - for (int i = 0; i < num_members; i++) { - tuple_set.push_back({get_fresh_idx()}); - } - tuple_map_[tup_idx] = tuple_set; + // the LHS (the bound var) will definitely have a tuple struct info + insert_fresh_tuple(tup_idx, GetStructInfoAs(bound_var)); } // call_tir: can potentially return a tuple - else if (op_node.name == "call_tir") { - if (auto* tuple_struct_info = call->sinfo_args[0].as()) { + else if (op_node->name == "relax.call_tir") { + if (auto* tuple_struct_info = call_node->sinfo_args[0].as()) { int tup_idx = get_fresh_idx(); ret.insert(tup_idx); - - int num_members = tuple_struct_info->fields.size(); - insert_fresh_tuple(tup_idx, num_members); + insert_fresh_tuple(tup_idx, tuple_struct_info); } else { ret.insert(get_fresh_idx()); } @@ -249,7 +240,7 @@ class AliasAnalyzer { } } else { // assume any non-op call can be extremely dangerous and do anything - return handle_mystery_call(call_node); + return handle_mystery_call(call_node, bound_var); } } @@ -262,13 +253,41 @@ class AliasAnalyzer { int mem_idx_; }; -// export for testing - // check for in-place eligibility: // 1. see if there's an arg big enough to hold the result // 2. see if the arg is live past the call // 3. see if the arg has an alias that's live past the call // if conditions are met, we're good to go +// export for testing +namespace transform { + +Map> DataflowLivenessAnalysis(const DataflowBlock& block) { + auto liveness_ranges = analyze_liveness(block); + Map> ret; + for (auto kv : liveness_ranges) { + ret.Set(kv.first, {kv.second.first, kv.second.second}); + } + return ret; +} + +Map> DataflowAliasAnalysis(const DataflowBlock& block, Array inputs) { + AliasAnalyzer analyzer; + auto alias_sets = analyzer.Analyze(block, inputs); + Map> ret; + for (auto kv : alias_sets) { + Array aliases; + for (auto alias : kv.second) { + aliases.push_back(alias); + } + ret.Set(kv.first, aliases); + } + return ret; +} + +TVM_REGISTER_GLOBAL("relax.analysis.DataflowLivenessAnalysis") + .set_body_typed(DataflowLivenessAnalysis); +TVM_REGISTER_GLOBAL("relax.analysis.DataflowAliasAnalysis").set_body_typed(DataflowAliasAnalysis); +} // namespace transform } // namespace relax } // namespace tvm \ No newline at end of file diff --git a/tests/python/relax/test_dataflow_in_place.py b/tests/python/relax/test_dataflow_in_place.py new file mode 100644 index 000000000000..54f7112e9e47 --- /dev/null +++ b/tests/python/relax/test_dataflow_in_place.py @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.relax.analysis import dataflow_liveness_analysis, dataflow_alias_analysis +from tvm.script.parser import ir as I, relax as R, tir as T + + +def test_liveness_analysis(): + @I.ir_module + class BasicLiveness: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + with R.dataflow(): + y = R.const(1, dtype="int32") + z = R.add(x, y) + q = R.multiply(z, y) + p = R.add(z, q) + n = R.multiply(p, p) + R.output(n) + return n + + block = BasicLiveness["main"].body.blocks[0] + live_ranges = dataflow_liveness_analysis(block) + expected_ranges = { + "x": (-1, 1), + "y": (0, 2), + "z": (1, 3), + "q": (2, 3), + "p": (3, 4), + "n": (4, 5), + } + for var, live_range in live_ranges.items(): + assert live_range == expected_ranges[var.name_hint] + + +def test_alias_analysis_basic(): + @I.ir_module + class BasicAliasAnalysis: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + with R.dataflow(): + y = x # y is an alias of x + z = R.add(y, y) # fresh value + n = z # alias of z + R.output(n) + return n + + block = BasicAliasAnalysis["main"].body.blocks[0] + alias_sets = dataflow_alias_analysis(block, BasicAliasAnalysis["main"].params) + expected = { + "x": {0}, + "y": {0}, + "z": {1}, + "n": {1}, + } + + for var, alias_set in alias_sets.items(): + assert alias_set == expected[var.name_hint] + + +def test_alias_analysis_tuple(): + @I.ir_module + class AliasesWithTuples: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + with R.dataflow(): + y = R.const(1, dtype="int32") + t = (x, y) + a = t[0] + b = t[1] + c = t[0] + d = t[1] + u = t + e = t[0] + f = t[1] + z = R.add(c, d) + n = z + R.output(n) + return n + + block = AliasesWithTuples["main"].body.blocks[0] + alias_sets = dataflow_alias_analysis(block, AliasesWithTuples["main"].params) + expected = { + "x": {0}, + "y": {1}, + "t": {2}, + "a": {0}, + "b": {1}, + "c": {0}, + "d": {1}, + "u": {2}, + "e": {0}, + "f": {1}, + "z": {3}, + "n": {3}, + } + + for var, alias_set in alias_sets.items(): + assert alias_set == expected[var.name_hint] + + +def test_alias_split(): + @I.ir_module + class AliasSplit: + @R.function + def main(x: R.Tensor((60,), "int32")) -> R.Tensor((15,), "int32"): + with R.dataflow(): + t = R.split(x, 4) + y = t[0] + z = t[1] + q = t[2] + p = t[3] + n = z + R.output(n) + return n + + block = AliasSplit["main"].body.blocks[0] + alias_sets = dataflow_alias_analysis(block, AliasSplit["main"].params) + expected = { + "x": {0}, + "t": {1}, + "y": {2}, + "z": {3}, + "q": {4}, + "p": {5}, + "n": {3}, + } + + for var, alias_set in alias_sets.items(): + assert alias_set == expected[var.name_hint] + + +def test_alias_call_tir(): + # call TIR can yield either a single tensor or a tuple + @I.ir_module + class AliasCallTir: + @T.prim_func + def tir_id(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "tir_id"}) + m = T.int32() + n = T.int32() + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (m, n)) + + for i, j in T.grid(m, n): + with T.block("id"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + + @T.prim_func + def tir_id2(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_id"}) + m = T.int32() + n = T.int32() + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (m, n)) + C = T.match_buffer(z, (m, n)) + + for i, j in T.grid(m, n): + with T.block("id"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + C[vi, vj] = A[vi, vj] + + @R.function + def main(x: R.Tensor((10, 10), "int32")) -> R.Tensor((10, 10), "int32"): + with R.dataflow(): + cls = AliasCallTir + y = R.call_tir(cls.tir_id, (x,), out_sinfo=R.Tensor((10, 10), "int32")) + t = R.call_tir( + cls.tir_id2, + (y,), + out_sinfo=[R.Tensor((10, 10), "int32"), R.Tensor((10, 10), "int32")], + ) + z = y + p = t[0] + q = t[1] + u = t + m = u[0] + n = u[1] + v = n + R.output(v) + return v + + block = AliasCallTir["main"].body.blocks[0] + alias_sets = dataflow_alias_analysis(block, AliasCallTir["main"].params) + expected = { + "x": {0}, + "y": {1}, + "t": {2}, + "z": {1}, + "p": {3}, + "q": {4}, + "u": {2}, + "m": {3}, + "n": {4}, + "v": {4}, + } + + for var, alias_set in alias_sets.items(): + assert alias_set == expected[var.name_hint] + + +def test_mystery_calls(): + @I.ir_module + class AliasChaosCalls: + @R.function + def identity(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + return x + + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + with R.dataflow(): + cls = AliasChaosCalls + y = cls.identity(x) + z = cls.identity(y) + m = R.const(1, dtype="int32") + n = R.const(2, dtype="int32") + t = (m, n) + a = R.call_pure_packed( + "chaos", t, sinfo_args=R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")) + ) + b = a[0] + c = a[1] + R.output(c) + return c + + block = AliasChaosCalls["main"].body.blocks[0] + alias_sets = dataflow_alias_analysis(block, AliasChaosCalls["main"].params) + expected = { + "x": {0}, + "y": {0, 1}, + "z": {0, 1, 2}, + "m": {3}, + "n": {4}, + "t": {5}, + "a": {0, 1, 2, 3, 4, 5, 6, 7, 8}, + "b": {-1}, # because a can be many things, b is unknown + "c": {-1}, # because a can be many things, c is unknown + } + + for var, alias_set in alias_sets.items(): + assert alias_set == expected[var.name_hint] + + +if __name__ == "__main__": + tvm.testing.main() From 3cc3c9b4cb7449dd8e33ff293f22f1de045e7bd7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 3 Oct 2023 20:37:08 -0400 Subject: [PATCH 04/55] Include in-place analysis --- python/tvm/relax/analysis/analysis.py | 15 +- src/relax/transform/dataflow_in_place.cc | 277 ++++++++++++++++++- tests/python/relax/test_dataflow_in_place.py | 24 +- 3 files changed, 298 insertions(+), 18 deletions(-) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index ddee364c9490..fe87f9a723eb 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -539,9 +539,14 @@ def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int return ret # type: ignore -def dataflow_alias_analysis(block: DataflowBlock, inputs: List[Var]) -> Dict[Var, Set[int]]: - alias_sets = _ffi_api.DataflowAliasAnalysis(block, inputs) # type: ignore - ret = {} +def dataflow_alias_analysis( + block: DataflowBlock, inputs: List[Var] +) -> Tuple[Dict[Var, Set[int]], Dict[int, List[Set[int]]]]: + alias_sets, tuple_map = _ffi_api.DataflowAliasAnalysis(block, inputs) # type: ignore + res_alias_sets = {} + res_tuple_map = {} for (var, alias_set) in alias_sets.items(): - ret[var] = set(alias_set) - return ret # type: ignore + res_alias_sets[var] = set(alias_set) + for (idx, elem_alias_sets) in tuple_map.items(): + res_tuple_map[idx] = [set(alias_set) for alias_set in elem_alias_sets] + return res_alias_sets, res_tuple_map # type: ignore diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 02d3e36b694d..21e0bb5b2d0a 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -76,8 +76,9 @@ class AliasAnalyzer { // alias: map of var to memory locations (we will call these indices and use -1 as an index for // "unknown") - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> Analyze( - const DataflowBlock& block, const Array& inputs) { + std::pair, ObjectPtrHash, ObjectPtrEqual>, + std::unordered_map>>> + Analyze(const DataflowBlock& block, const Array& inputs) { for (auto input : inputs) { int curr_idx = get_fresh_idx(); alias_map_[input] = {curr_idx}; @@ -99,7 +100,7 @@ class AliasAnalyzer { alias_map_[current_var] = get_alias_set(value, current_var); } - return alias_map_; + return {alias_map_, tuple_map_}; } private: @@ -253,11 +254,250 @@ class AliasAnalyzer { int mem_idx_; }; +int shape_size(const ShapeExpr& shape) { + int ret = 1; + for (auto dim : shape->values) { + if (auto int_dim = dim.as()) { + ret *= static_cast(int_dim->value); + } else { + return -1; + } + } + return ret; +} + +std::pair size_matches(const StructInfo& target_info, const StructInfo& arg_info) { + if (target_info.as() && arg_info.as()) { + auto target_tensor = Downcast(target_info); + auto arg_tensor = Downcast(arg_info); + if (target_tensor->shape.defined() && target_tensor->shape.as() && + arg_tensor->shape.defined() && arg_tensor->shape.as()) { + auto target_shape = Downcast(target_tensor->shape); + auto arg_shape = Downcast(arg_tensor->shape); + int target_size = shape_size(target_shape); + int arg_size = shape_size(arg_shape); + if (target_size == -1 || arg_size == -1 || target_size != arg_size) { + return {false, false}; + } + // exact match: number of dims and each dim matches + if (target_shape->values.size() == arg_shape->values.size()) { + for (size_t i = 0; i < target_shape->values.size(); i++) { + if (Downcast(target_shape->values[i])->value != + Downcast(arg_shape->values[i])->value) { + return {true, false}; + } + } + return {true, true}; + } + return {true, false}; + } else { + return {false, false}; + } + } else if (target_info.as() && arg_info.as()) { + auto target_tup = Downcast(target_info); + auto arg_tup = Downcast(arg_info); + if (target_tup->fields.size() != arg_tup->fields.size()) { + return {false, false}; + } + bool all_exact = true; + for (size_t i = 0; i < target_tup->fields.size(); i++) { + auto element_match = size_matches(target_tup->fields[i], arg_tup->fields[i]); + if (!element_match.first) { + return {false, false}; + } + if (!element_match.second) { + all_exact = false; + } + } + return {true, all_exact}; + } else if (target_info.as() && arg_info.as()) { + return {true, true}; + } else { + return {false, false}; + } +} + +bool intersecting_live_aliases( + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& live_ranges, + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& alias_sets, + std::unordered_map>>& tuple_map, + std::unordered_set& currently_live, const Expr& target, + int idx) { + if (auto* var_node = target.as()) { + auto current_var = GetRef(var_node); + // no entry for the current var -> it must be something external and we have to assume the worst + if (!alias_sets.count(current_var)) { + return true; + } + auto alias_set = alias_sets[current_var]; + // -1 -> an external value and we must assume the worst + if (alias_set.count(-1)) { + return true; + } + std::vector> sets_to_check = {alias_set}; + std::unordered_set indices_checked; + // if a possible alias is a tuple, we will also check for aliases of the members + for (int alias_idx : alias_set) { + if (tuple_map.count(alias_idx)) { + for (auto member_set : tuple_map[alias_idx]) { + if (member_set.count(-1)) { + return true; + } + sets_to_check.push_back(member_set); + } + } + } + + for (Var other_var : currently_live) { + if (!alias_sets.count(other_var) || !live_ranges.count(other_var)) { + return true; + } + // var is not live past this point => don't need to worry + if (live_ranges[other_var].second <= idx) { + continue; + } + auto other_alias_set = alias_sets[other_var]; + for (int alias_idx : other_alias_set) { + for (auto check_set : sets_to_check) { + if (check_set.count(alias_idx)) { + return true; + } + } + } + } + return false; + } else if (auto* tup_node = target.as()) { + for (auto field : tup_node->fields) { + if (intersecting_live_aliases(live_ranges, alias_sets, tuple_map, currently_live, field, + idx)) { + return true; + } + } + return false; + } else { + return false; + } +} + // check for in-place eligibility: // 1. see if there's an arg big enough to hold the result // 2. see if the arg is live past the call // 3. see if the arg has an alias that's live past the call // if conditions are met, we're good to go +void find_inplace_opportunities(const DataflowBlock& block, const Array& inputs) { + auto live_ranges = analyze_liveness(block); + AliasAnalyzer analyzer; + auto alias_info = analyzer.Analyze(block, inputs); + auto alias_sets = alias_info.first; + auto tuple_map = alias_info.second; + + // sort the live ranges by starting index + std::vector live_order; + for (auto kv : live_ranges) { + live_order.push_back(kv.first); + } + std::sort(live_order.begin(), live_order.end(), + [&live_ranges](const Var& var1, const Var& var2) -> bool { + return live_ranges[var1].first < live_ranges[var2].first; + }); + + std::unordered_set currently_live; + for (auto var : live_order) { + auto live_range = live_ranges[var]; + if (live_range.first > 0) { + break; + } + currently_live.insert(var); + } + + for (size_t i = 0; i < block->bindings.size(); i++) { + // if we reach a binding check the conditions + Binding b = block->bindings[i]; + Var defined_var = b->var; + Expr value; + if (const auto* var_binding = b.as()) { + value = var_binding->value; + } else if (const auto* match_binding = b.as()) { + value = match_binding->value; + } else { + CHECK(false) << "Invalid binding"; // impossible + } + + if (auto* call_node = value.as()) { + if (call_node->op.as()) { + std::unordered_set candidates; + std::unordered_set exact_match_candidates; + + // 1. Check that at least one argument matches size with the result + for (auto arg : call_node->args) { + std::pair match = + size_matches(GetStructInfo(defined_var), GetStructInfo(arg)); + if (match.first) { + candidates.insert(arg); + if (match.second) { + exact_match_candidates.insert(arg); + } + } + } + if (!candidates.size()) { + continue; + } + + // 2. Make sure at least one candidate is not live past this point + std::unordered_set remove_candidates; + for (auto candidate : candidates) { + // only var nodes need to be checked; other leaf exprs (e.g., tuples) are live + // only in the current binding unless they're bound + if (auto* var_node = candidate.as()) { + // live past the current binding -> remove from candidates + auto arg_var = GetRef(var_node); + if (live_ranges.count(arg_var)) { + auto live_range = live_ranges[arg_var]; + if (live_range.second > static_cast(i)) { + remove_candidates.insert(candidate); + } + } + } + } + candidates.erase(remove_candidates.begin(), remove_candidates.end()); + if (!candidates.size()) { + continue; + } + + // 3. Make sure at least one candidate does not have an alias live past this point + remove_candidates.clear(); + for (auto candidate : candidates) { + if (intersecting_live_aliases(live_ranges, alias_sets, tuple_map, currently_live, + candidate, i)) { + remove_candidates.insert(candidate); + } + } + candidates.erase(remove_candidates.begin(), remove_candidates.end()); + + // if we have a candidate, then this can be made in-place. Report the result + std::cout << "Operation " << i << " (" << value << ") can be made in-place"; + for (auto candidate : candidates) { + if (exact_match_candidates.count(candidate)) { + std::cout << " (exact dimension match)"; + break; + } + } + std::cout << std::endl; + } + } + + // remove vars whose range has come to an end + // (keep a separate set to avoid changing the sit while iterating on it) + std::unordered_set remove; + for (auto var : currently_live) { + auto live_range = live_ranges[var]; + if (live_range.second <= static_cast(i)) { + remove.insert(var); + } + } + currently_live.erase(remove.begin(), remove.end()); + } +} // export for testing namespace transform { @@ -271,23 +511,44 @@ Map> DataflowLivenessAnalysis(const DataflowBlock& block) { return ret; } -Map> DataflowAliasAnalysis(const DataflowBlock& block, Array inputs) { +Array DataflowAliasAnalysis(const DataflowBlock& block, Array inputs) { AliasAnalyzer analyzer; - auto alias_sets = analyzer.Analyze(block, inputs); - Map> ret; + auto res = analyzer.Analyze(block, inputs); + auto alias_sets = res.first; + auto tuple_map = res.second; + Map> new_alias_sets; + Map>> new_tuple_map; for (auto kv : alias_sets) { Array aliases; for (auto alias : kv.second) { aliases.push_back(alias); } - ret.Set(kv.first, aliases); + new_alias_sets.Set(kv.first, aliases); } - return ret; + for (auto kv : tuple_map) { + Array> elem_aliases; + for (auto alias_set : kv.second) { + Array dim_aliases; + for (auto alias : alias_set) { + dim_aliases.push_back(alias); + } + elem_aliases.push_back(dim_aliases); + } + new_tuple_map.Set(kv.first, elem_aliases); + } + return {new_alias_sets, new_tuple_map}; +} + +void DataflowInPlaceAnalysis(const DataflowBlock& block, const Array& inputs) { + relax::find_inplace_opportunities(block, inputs); } TVM_REGISTER_GLOBAL("relax.analysis.DataflowLivenessAnalysis") .set_body_typed(DataflowLivenessAnalysis); TVM_REGISTER_GLOBAL("relax.analysis.DataflowAliasAnalysis").set_body_typed(DataflowAliasAnalysis); +TVM_REGISTER_GLOBAL("relax.analysis.DataflowInPlaceAnalasis") + .set_body_typed(DataflowInPlaceAnalysis); + } // namespace transform } // namespace relax } // namespace tvm \ No newline at end of file diff --git a/tests/python/relax/test_dataflow_in_place.py b/tests/python/relax/test_dataflow_in_place.py index 54f7112e9e47..6d0fbe0eef46 100644 --- a/tests/python/relax/test_dataflow_in_place.py +++ b/tests/python/relax/test_dataflow_in_place.py @@ -61,7 +61,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): return n block = BasicAliasAnalysis["main"].body.blocks[0] - alias_sets = dataflow_alias_analysis(block, BasicAliasAnalysis["main"].params) + alias_sets, tuple_map = dataflow_alias_analysis(block, BasicAliasAnalysis["main"].params) expected = { "x": {0}, "y": {0}, @@ -71,6 +71,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): for var, alias_set in alias_sets.items(): assert alias_set == expected[var.name_hint] + assert tuple_map == {} def test_alias_analysis_tuple(): @@ -94,7 +95,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): return n block = AliasesWithTuples["main"].body.blocks[0] - alias_sets = dataflow_alias_analysis(block, AliasesWithTuples["main"].params) + alias_sets, tuple_map = dataflow_alias_analysis(block, AliasesWithTuples["main"].params) expected = { "x": {0}, "y": {1}, @@ -112,6 +113,8 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): for var, alias_set in alias_sets.items(): assert alias_set == expected[var.name_hint] + assert 2 in tuple_map + assert tuple_map[2] == [{0}, {1}] def test_alias_split(): @@ -130,7 +133,7 @@ def main(x: R.Tensor((60,), "int32")) -> R.Tensor((15,), "int32"): return n block = AliasSplit["main"].body.blocks[0] - alias_sets = dataflow_alias_analysis(block, AliasSplit["main"].params) + alias_sets, tuple_map = dataflow_alias_analysis(block, AliasSplit["main"].params) expected = { "x": {0}, "t": {1}, @@ -143,6 +146,9 @@ def main(x: R.Tensor((60,), "int32")) -> R.Tensor((15,), "int32"): for var, alias_set in alias_sets.items(): assert alias_set == expected[var.name_hint] + assert len(tuple_map) == 1 + assert 1 in tuple_map + assert tuple_map[1] == [{2}, {3}, {4}, {5}] def test_alias_call_tir(): @@ -198,7 +204,7 @@ def main(x: R.Tensor((10, 10), "int32")) -> R.Tensor((10, 10), "int32"): return v block = AliasCallTir["main"].body.blocks[0] - alias_sets = dataflow_alias_analysis(block, AliasCallTir["main"].params) + alias_sets, tuple_map = dataflow_alias_analysis(block, AliasCallTir["main"].params) expected = { "x": {0}, "y": {1}, @@ -214,6 +220,9 @@ def main(x: R.Tensor((10, 10), "int32")) -> R.Tensor((10, 10), "int32"): for var, alias_set in alias_sets.items(): assert alias_set == expected[var.name_hint] + assert len(tuple_map) == 1 + assert 2 in tuple_map + assert tuple_map[2] == [{3}, {4}] def test_mystery_calls(): @@ -241,7 +250,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): return c block = AliasChaosCalls["main"].body.blocks[0] - alias_sets = dataflow_alias_analysis(block, AliasChaosCalls["main"].params) + alias_sets, tuple_map = dataflow_alias_analysis(block, AliasChaosCalls["main"].params) expected = { "x": {0}, "y": {0, 1}, @@ -256,6 +265,11 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): for var, alias_set in alias_sets.items(): assert alias_set == expected[var.name_hint] + assert len(tuple_map) == 2 + assert 5 in tuple_map + assert tuple_map[5] == [{3}, {4}] + assert 6 in tuple_map + assert tuple_map[6] == [{7}, {8}] if __name__ == "__main__": From 441ead2a3d55eb74d64c4c916b38e76d9f98835c Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 3 Oct 2023 20:46:35 -0400 Subject: [PATCH 05/55] Return the lists instead --- src/relax/transform/dataflow_in_place.cc | 29 +++++++++++++++++++----- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 21e0bb5b2d0a..dbc1bacde604 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -384,13 +384,17 @@ bool intersecting_live_aliases( // 2. see if the arg is live past the call // 3. see if the arg has an alias that's live past the call // if conditions are met, we're good to go -void find_inplace_opportunities(const DataflowBlock& block, const Array& inputs) { +std::pair, std::vector> find_inplace_opportunities(const DataflowBlock& block, + const Array& inputs) { auto live_ranges = analyze_liveness(block); AliasAnalyzer analyzer; auto alias_info = analyzer.Analyze(block, inputs); auto alias_sets = alias_info.first; auto tuple_map = alias_info.second; + std::vector size_match_list; + std::vector exact_match_list; + // sort the live ranges by starting index std::vector live_order; for (auto kv : live_ranges) { @@ -475,14 +479,15 @@ void find_inplace_opportunities(const DataflowBlock& block, const Array& in candidates.erase(remove_candidates.begin(), remove_candidates.end()); // if we have a candidate, then this can be made in-place. Report the result - std::cout << "Operation " << i << " (" << value << ") can be made in-place"; + if (candidates.size()) { + size_match_list.push_back(i); + } for (auto candidate : candidates) { if (exact_match_candidates.count(candidate)) { - std::cout << " (exact dimension match)"; + exact_match_list.push_back(i); break; } } - std::cout << std::endl; } } @@ -497,6 +502,8 @@ void find_inplace_opportunities(const DataflowBlock& block, const Array& in } currently_live.erase(remove.begin(), remove.end()); } + + return {size_match_list, exact_match_list}; } // export for testing @@ -539,8 +546,18 @@ Array DataflowAliasAnalysis(const DataflowBlock& block, Array in return {new_alias_sets, new_tuple_map}; } -void DataflowInPlaceAnalysis(const DataflowBlock& block, const Array& inputs) { - relax::find_inplace_opportunities(block, inputs); +Array> DataflowInPlaceAnalysis(const DataflowBlock& block, + const Array& inputs) { + auto index_lists = relax::find_inplace_opportunities(block, inputs); + Array size_match_array; + for (int index : index_lists.first) { + size_match_array.push_back(index); + } + Array exact_match_array; + for (int index : index_lists.second) { + exact_match_array.push_back(index); + } + return {size_match_array, exact_match_array}; } TVM_REGISTER_GLOBAL("relax.analysis.DataflowLivenessAnalysis") From 71d32a1ce8073a4238e3fadf0e9b172b314bf84e Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 3 Oct 2023 21:50:41 -0400 Subject: [PATCH 06/55] Update python binding --- python/tvm/relax/analysis/analysis.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index fe87f9a723eb..29eb760eb7f2 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -550,3 +550,10 @@ def dataflow_alias_analysis( for (idx, elem_alias_sets) in tuple_map.items(): res_tuple_map[idx] = [set(alias_set) for alias_set in elem_alias_sets] return res_alias_sets, res_tuple_map # type: ignore + + +def dataflow_inplace_analysis( + block: DataflowBlock, inputs: List[Var] +) -> Tuple[List[int], List[int]]: + index_lists = _ffi_api.DataflowInPlaceAnalysis(block, inputs) # type: ignore + return tuple(index_lists) # type: ignore From 5a3d3f21a49fc1edecfa16a528f95eeb85f2c4d9 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 5 Oct 2023 16:05:57 -0400 Subject: [PATCH 07/55] No need to assume *pure* functions capture all values ever passed to them. Also use pointers instead of non-const refs --- src/relax/transform/dataflow_in_place.cc | 14 +++++++------- tests/python/relax/test_dataflow_in_place.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index dbc1bacde604..7dc765284986 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -124,20 +124,20 @@ class AliasAnalyzer { // capture the given index and also its tuple components (including recursively) // if they exist - void add_to_captured_set(int idx) { - captured_by_functions_.insert(idx); + void add_captured_indices(std::unordered_set* captured_set, int idx) { + captured_set->insert(idx); if (tuple_map_.count(idx)) { for (auto comp_set : tuple_map_[idx]) { for (auto tup_comp_idx : comp_set) { - add_to_captured_set(tup_comp_idx); + add_captured_indices(captured_set, tup_comp_idx); } } } } // Conservative extremely pessimistic assumption: - // assume that the result of a non-op call can be aliased to anything - // ever passed to or returned from any non-op call. + // assume that the result of a non-op call can be aliased to any argument + // or that it could be a newly allocated value. // For tuples, assume all members are aliased. Yeah, it's bad. // (Skip first arg is for handling call_pure_packed, where the first arg is an ExternFunc that we // should ignore) @@ -150,13 +150,13 @@ class AliasAnalyzer { if (auto* tup_info_node = GetStructInfoAs(bound_var)) { insert_fresh_tuple(res_idx, tup_info_node); } - add_to_captured_set(res_idx); + add_captured_indices(&ret, res_idx); for (size_t i = (skip_first_arg) ? 1 : 0; i < call_node->args.size(); i++) { auto arg = call_node->args[i]; auto arg_alias_set = get_alias_set(arg, bound_var); for (int alias_idx : arg_alias_set) { - add_to_captured_set(alias_idx); + add_captured_indices(&ret, alias_idx); } } ret.insert(captured_by_functions_.begin(), captured_by_functions_.end()); diff --git a/tests/python/relax/test_dataflow_in_place.py b/tests/python/relax/test_dataflow_in_place.py index 6d0fbe0eef46..7aa12b035d78 100644 --- a/tests/python/relax/test_dataflow_in_place.py +++ b/tests/python/relax/test_dataflow_in_place.py @@ -258,7 +258,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): "m": {3}, "n": {4}, "t": {5}, - "a": {0, 1, 2, 3, 4, 5, 6, 7, 8}, + "a": {3, 4, 5, 6, 7, 8}, "b": {-1}, # because a can be many things, b is unknown "c": {-1}, # because a can be many things, c is unknown } From 632c773822d01d63615b8a661688ea81befedbfb Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 5 Oct 2023 17:12:41 -0400 Subject: [PATCH 08/55] Improve handling of tuples in mystery call case --- src/relax/transform/dataflow_in_place.cc | 23 +++++++++++++++++++- tests/python/relax/test_dataflow_in_place.py | 9 ++++---- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 7dc765284986..3a464ee77fad 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -122,6 +122,25 @@ class AliasAnalyzer { tuple_map_[tup_idx] = tuple_set; } + void update_tuple_components(int tup_idx, const std::unordered_set& insert_idxs) { + if (tuple_map_.count(tup_idx)) { + auto tuple_comps = tuple_map_[tup_idx]; + for (size_t i = 0; i < tuple_comps.size(); i++) { + auto comp_set = tuple_comps[i]; + + // if a member is a tuple, update its components as well + for (int member : comp_set) { + if (tuple_map_.count(member)) { + update_tuple_components(member, insert_idxs); + } + } + + // update after iterating to avoid iterating over the inserted elements + tuple_map_[tup_idx][i].insert(insert_idxs.begin(), insert_idxs.end()); + } + } + } + // capture the given index and also its tuple components (including recursively) // if they exist void add_captured_indices(std::unordered_set* captured_set, int idx) { @@ -159,7 +178,9 @@ class AliasAnalyzer { add_captured_indices(&ret, alias_idx); } } - ret.insert(captured_by_functions_.begin(), captured_by_functions_.end()); + // if the result is a tuple, the components can also potentially be aliased to any arg + // or, in fact, to each other + update_tuple_components(res_idx, ret); return ret; } diff --git a/tests/python/relax/test_dataflow_in_place.py b/tests/python/relax/test_dataflow_in_place.py index 7aa12b035d78..b5e97fdb1f89 100644 --- a/tests/python/relax/test_dataflow_in_place.py +++ b/tests/python/relax/test_dataflow_in_place.py @@ -258,9 +258,10 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): "m": {3}, "n": {4}, "t": {5}, - "a": {3, 4, 5, 6, 7, 8}, - "b": {-1}, # because a can be many things, b is unknown - "c": {-1}, # because a can be many things, c is unknown + "a": {3, 4, 5, 6, 7, 8}, # either t or a fresh tuple + "b": {3, 4, 5, 6, 7, 8}, # the tuple components can be aliased to any member... + "c": {3, 4, 5, 6, 7, 8}, # the tuple components can be aliased to any member... + # (in principle, we can use type information to narrow down the aliasing) } for var, alias_set in alias_sets.items(): @@ -269,7 +270,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): assert 5 in tuple_map assert tuple_map[5] == [{3}, {4}] assert 6 in tuple_map - assert tuple_map[6] == [{7}, {8}] + assert tuple_map[6] == [{3, 4, 5, 6, 7, 8}, {3, 4, 5, 6, 7, 8}] if __name__ == "__main__": From d97ab92c700e97e8d1048af4e13a216b8621b16f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 5 Oct 2023 17:20:32 -0400 Subject: [PATCH 09/55] Corrections to inplace checking --- src/relax/transform/dataflow_in_place.cc | 146 +++++++++++++---------- 1 file changed, 83 insertions(+), 63 deletions(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 3a464ee77fad..fcb7de77b147 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -72,7 +72,7 @@ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> anal class AliasAnalyzer { public: - explicit AliasAnalyzer() : alias_map_(), tuple_map_(), captured_by_functions_(), mem_idx_(0) {} + explicit AliasAnalyzer() : alias_map_(), tuple_map_(), mem_idx_(0) {} // alias: map of var to memory locations (we will call these indices and use -1 as an index for // "unknown") @@ -219,17 +219,20 @@ class AliasAnalyzer { tuple_map_[tup_idx] = new_tuple_map; } else if (auto* target_tgi = value.as()) { std::unordered_set tuple_set = get_alias_set(target_tgi->tuple, bound_var); - // if there's only one possibility for the tuple and it's in the tuple map, - // index into it - if (tuple_set.size() == 1) { - int index = *(tuple_set.begin()); - if (tuple_map_.count(index)) { - return tuple_map_[index][target_tgi->index]; - } else { - ret.insert(-1); - } - } else { + // if -1 is a member of the tuple set, then we have to assume the result is -1 + if (tuple_set.count(-1)) { ret.insert(-1); + } else { + // otherwise, consider all members that are tuples of appropriate size and index into them + // (this is safe because the type system will ensure we're not indexing into a tuple + // of the wrong size) + for (int member : tuple_set) { + if (tuple_map_.count(member) && + static_cast(tuple_map_[member].size()) > target_tgi->index) { + auto member_set = tuple_map_[member][target_tgi->index]; + ret.insert(member_set.begin(), member_set.end()); + } + } } } else if (auto* call_node = value.as()) { if (auto* op_node = call_node->op.as()) { @@ -271,7 +274,6 @@ class AliasAnalyzer { std::unordered_map, ObjectPtrHash, ObjectPtrEqual> alias_map_; std::unordered_map>> tuple_map_; - std::unordered_set captured_by_functions_; int mem_idx_; }; @@ -338,65 +340,102 @@ std::pair size_matches(const StructInfo& target_info, const StructIn } } -bool intersecting_live_aliases( - std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& live_ranges, - std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& alias_sets, - std::unordered_map>>& tuple_map, - std::unordered_set& currently_live, const Expr& target, - int idx) { +// Given an alias index, check if it's a tuple and gather the sets of aliases for the tuple +// members if so (apply recursively if any of those members are tuples). +// Return false if the alias set contains -1, meaning a reference to an unknown or +// possibly dangerous value (no checking we can do for that). +bool gather_sets_to_check_for_liveness( + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + alias_sets, + const std::unordered_map>>& tuple_map, + std::vector>* sets_to_check, int alias_idx) { + if (tuple_map.count(alias_idx)) { + for (auto member_set : tuple_map.at(alias_idx)) { + // contains -1 -> unknown and dangerous, we can short-circuit + if (member_set.count(-1)) { + return false; + } + sets_to_check->push_back(member_set); + + // if a member can be a tuple, check it recursively + for (int member : member_set) { + if (tuple_map.count(member)) { + if (!gather_sets_to_check_for_liveness(alias_sets, tuple_map, sets_to_check, member)) { + return false; + } + } + } + } + } + return true; +} + +// check that the target is not live past the index and that no alias of it is live past the +// binding index (if the target is a tuple, check the conditions recursively for the members) +bool df_inplace_conditions_met( + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& live_ranges, + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + alias_sets, + const std::unordered_map>>& tuple_map, + const std::unordered_set& currently_live, + const Expr& target, int idx) { if (auto* var_node = target.as()) { auto current_var = GetRef(var_node); + // if the var is live past this point, we can't use it for in-place computations anyway + if (live_ranges.count(current_var)) { + auto live_range = live_ranges.at(current_var); + if (live_range.second > idx) { + return false; + } + } + // no entry for the current var -> it must be something external and we have to assume the worst if (!alias_sets.count(current_var)) { - return true; + return false; } - auto alias_set = alias_sets[current_var]; + auto alias_set = alias_sets.at(current_var); // -1 -> an external value and we must assume the worst if (alias_set.count(-1)) { - return true; + return false; } std::vector> sets_to_check = {alias_set}; std::unordered_set indices_checked; - // if a possible alias is a tuple, we will also check for aliases of the members + // If a possible alias is a tuple, we will also check for aliases of the members + // (possibly recursively) for (int alias_idx : alias_set) { - if (tuple_map.count(alias_idx)) { - for (auto member_set : tuple_map[alias_idx]) { - if (member_set.count(-1)) { - return true; - } - sets_to_check.push_back(member_set); - } + if (!gather_sets_to_check_for_liveness(alias_sets, tuple_map, &sets_to_check, alias_idx)) { + return false; } } for (Var other_var : currently_live) { if (!alias_sets.count(other_var) || !live_ranges.count(other_var)) { - return true; + return false; } // var is not live past this point => don't need to worry - if (live_ranges[other_var].second <= idx) { + if (live_ranges.at(other_var).second <= idx) { continue; } - auto other_alias_set = alias_sets[other_var]; + auto other_alias_set = alias_sets.at(other_var); for (int alias_idx : other_alias_set) { for (auto check_set : sets_to_check) { if (check_set.count(alias_idx)) { - return true; + return false; } } } } - return false; + return true; } else if (auto* tup_node = target.as()) { for (auto field : tup_node->fields) { - if (intersecting_live_aliases(live_ranges, alias_sets, tuple_map, currently_live, field, - idx)) { - return true; + if (!df_inplace_conditions_met(live_ranges, alias_sets, tuple_map, currently_live, field, + idx)) { + return false; } } - return false; + return true; } else { - return false; + return true; } } @@ -453,7 +492,7 @@ std::pair, std::vector> find_inplace_opportunities(const D std::unordered_set candidates; std::unordered_set exact_match_candidates; - // 1. Check that at least one argument matches size with the result + // Check that at least one argument matches size with the result for (auto arg : call_node->args) { std::pair match = size_matches(GetStructInfo(defined_var), GetStructInfo(arg)); @@ -468,32 +507,13 @@ std::pair, std::vector> find_inplace_opportunities(const D continue; } - // 2. Make sure at least one candidate is not live past this point + // Make sure at least one candidate is not live past this point and does not have an alias + // live past this point std::unordered_set remove_candidates; - for (auto candidate : candidates) { - // only var nodes need to be checked; other leaf exprs (e.g., tuples) are live - // only in the current binding unless they're bound - if (auto* var_node = candidate.as()) { - // live past the current binding -> remove from candidates - auto arg_var = GetRef(var_node); - if (live_ranges.count(arg_var)) { - auto live_range = live_ranges[arg_var]; - if (live_range.second > static_cast(i)) { - remove_candidates.insert(candidate); - } - } - } - } - candidates.erase(remove_candidates.begin(), remove_candidates.end()); - if (!candidates.size()) { - continue; - } - - // 3. Make sure at least one candidate does not have an alias live past this point remove_candidates.clear(); for (auto candidate : candidates) { - if (intersecting_live_aliases(live_ranges, alias_sets, tuple_map, currently_live, - candidate, i)) { + if (!df_inplace_conditions_met(live_ranges, alias_sets, tuple_map, currently_live, + candidate, i)) { remove_candidates.insert(candidate); } } From 3f099df8334814a9f543597d7a4a4aae25d404fe Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 5 Oct 2023 21:19:14 -0400 Subject: [PATCH 10/55] Add test case for mystery value --- tests/python/relax/test_dataflow_in_place.py | 34 ++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/python/relax/test_dataflow_in_place.py b/tests/python/relax/test_dataflow_in_place.py index b5e97fdb1f89..ade7ef17285f 100644 --- a/tests/python/relax/test_dataflow_in_place.py +++ b/tests/python/relax/test_dataflow_in_place.py @@ -273,5 +273,39 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): assert tuple_map[6] == [{3, 4, 5, 6, 7, 8}, {3, 4, 5, 6, 7, 8}] +def test_alias_external_value(): + @I.ir_module + class AliasExternalValue: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.const(1, dtype="int32") # not in DF block, treated as external + t1 = (y, y) # not in DF block, treated as external + with R.dataflow(): + z = y # mystery value + a = R.const(2, dtype="int32") + t2 = (z, a) + b = t2[0] + c = t1[1] # tuple index into external value + R.output(b) + return b + + block = AliasExternalValue["main"].body.blocks[1] + alias_sets, tuple_map = dataflow_alias_analysis(block, AliasExternalValue["main"].params) + expected = { + "x": {0}, + "z": {-1}, + "a": {1}, + "t2": {2}, + "b": {-1}, + "c": {-1}, + } + + for var, alias_set in alias_sets.items(): + assert alias_set == expected[var.name_hint] + assert len(tuple_map) == 1 + assert 2 in tuple_map + assert tuple_map[2] == [{-1}, {1}] + + if __name__ == "__main__": tvm.testing.main() From 20189801096fa4627b8ddd01ce9810456af7a2b8 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 5 Oct 2023 21:27:41 -0400 Subject: [PATCH 11/55] typo --- src/relax/transform/dataflow_in_place.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index fcb7de77b147..e22330b61671 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -604,7 +604,7 @@ Array> DataflowInPlaceAnalysis(const DataflowBlock& block, TVM_REGISTER_GLOBAL("relax.analysis.DataflowLivenessAnalysis") .set_body_typed(DataflowLivenessAnalysis); TVM_REGISTER_GLOBAL("relax.analysis.DataflowAliasAnalysis").set_body_typed(DataflowAliasAnalysis); -TVM_REGISTER_GLOBAL("relax.analysis.DataflowInPlaceAnalasis") +TVM_REGISTER_GLOBAL("relax.analysis.DataflowInPlaceAnalysis") .set_body_typed(DataflowInPlaceAnalysis); } // namespace transform From cfe03d1243e5daee89b22cc6af02b1cf6f6ee7d2 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 5 Oct 2023 22:01:26 -0400 Subject: [PATCH 12/55] Add inplace test case, correct minor issues --- python/tvm/relax/analysis/analysis.py | 2 +- src/relax/transform/dataflow_in_place.cc | 32 ++++++++++++------ tests/python/relax/test_dataflow_in_place.py | 34 ++++++++++++++++++-- 3 files changed, 55 insertions(+), 13 deletions(-) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 29eb760eb7f2..5cae33a0c7a2 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -556,4 +556,4 @@ def dataflow_inplace_analysis( block: DataflowBlock, inputs: List[Var] ) -> Tuple[List[int], List[int]]: index_lists = _ffi_api.DataflowInPlaceAnalysis(block, inputs) # type: ignore - return tuple(index_lists) # type: ignore + return tuple(map(list, index_lists)) # type: ignore diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index e22330b61671..1d45e6cddebe 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -466,15 +466,20 @@ std::pair, std::vector> find_inplace_opportunities(const D }); std::unordered_set currently_live; - for (auto var : live_order) { - auto live_range = live_ranges[var]; - if (live_range.first > 0) { - break; - } - currently_live.insert(var); - } + int last_live = 0; for (size_t i = 0; i < block->bindings.size(); i++) { + // include all vars that are currently live + for (int j = last_live; j < static_cast(live_order.size()); j++) { + auto live_var = live_order[j]; + auto live_range = live_ranges[live_var]; + if (live_range.first > static_cast(i)) { + break; + } + currently_live.insert(live_var); + last_live++; + } + // if we reach a binding check the conditions Binding b = block->bindings[i]; Var defined_var = b->var; @@ -510,14 +515,18 @@ std::pair, std::vector> find_inplace_opportunities(const D // Make sure at least one candidate is not live past this point and does not have an alias // live past this point std::unordered_set remove_candidates; - remove_candidates.clear(); for (auto candidate : candidates) { if (!df_inplace_conditions_met(live_ranges, alias_sets, tuple_map, currently_live, candidate, i)) { remove_candidates.insert(candidate); } } - candidates.erase(remove_candidates.begin(), remove_candidates.end()); + // bizarre bug: this works, + // but candidates.erase(remove_candidates.begin(), remove_candidates.end()) + // gives a segfault + for (auto candidate : remove_candidates) { + candidates.erase(candidate); + } // if we have a candidate, then this can be made in-place. Report the result if (candidates.size()) { @@ -541,7 +550,10 @@ std::pair, std::vector> find_inplace_opportunities(const D remove.insert(var); } } - currently_live.erase(remove.begin(), remove.end()); + // same issue, using the ranged erase causes a segfault + for (auto var : remove) { + currently_live.erase(var); + } } return {size_match_list, exact_match_list}; diff --git a/tests/python/relax/test_dataflow_in_place.py b/tests/python/relax/test_dataflow_in_place.py index ade7ef17285f..f06138961849 100644 --- a/tests/python/relax/test_dataflow_in_place.py +++ b/tests/python/relax/test_dataflow_in_place.py @@ -16,7 +16,12 @@ # under the License. import tvm -from tvm.relax.analysis import dataflow_liveness_analysis, dataflow_alias_analysis +from tvm import testing +from tvm.relax.analysis import ( + dataflow_liveness_analysis, + dataflow_alias_analysis, + dataflow_inplace_analysis, +) from tvm.script.parser import ir as I, relax as R, tir as T @@ -307,5 +312,30 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): assert tuple_map[2] == [{-1}, {1}] +def test_inplace_simple_case(): + @I.ir_module + class InplaceBasic: + @R.function + def main( + x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") + ) -> R.Tensor((2, 3), "int32"): + with R.dataflow(): + z = R.add(x, y) # cannot be done inplace: x and y are live later + q = R.multiply(x, y) # can be done inplace: neither x nor y is used later + p = R.add(z, q) # can be done inplace: neither z nor q is used later + r = p # alias of p + m = R.multiply(p, p) # p is not used later but r is, so can't do inplace + n = R.add(m, r) # can be done inplace: neither is used again + l = R.reshape(n, (1, 2, 3)) # same size but not not identical shape + ret = R.reshape(l, (2, 3)) # same size but not identical shape + R.output(ret) + return ret + + block = InplaceBasic["main"].body.blocks[0] + size_match, exact_match = dataflow_inplace_analysis(block, InplaceBasic["main"].params) + assert size_match == [1, 2, 5, 6, 7] + assert exact_match == [1, 2, 5] + + if __name__ == "__main__": - tvm.testing.main() + testing.main() From 945c206e4af81034918b09171b7bee6d90cdfd0c Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 6 Oct 2023 14:47:09 -0400 Subject: [PATCH 13/55] Consider also using larger tensors to store smaller ones --- src/relax/transform/dataflow_in_place.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 1d45e6cddebe..694743f4789f 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -299,7 +299,7 @@ std::pair size_matches(const StructInfo& target_info, const StructIn auto arg_shape = Downcast(arg_tensor->shape); int target_size = shape_size(target_shape); int arg_size = shape_size(arg_shape); - if (target_size == -1 || arg_size == -1 || target_size != arg_size) { + if (target_size == -1 || arg_size == -1 || target_size < arg_size) { return {false, false}; } // exact match: number of dims and each dim matches From 0e6027a0f15a6f4dc91509385173056267bd17ec Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 6 Oct 2023 15:10:25 -0400 Subject: [PATCH 14/55] Check call args against any possible target sinfo, also check tensor sinfo dtype --- src/relax/transform/dataflow_in_place.cc | 62 +++++++++++++++++++++--- 1 file changed, 56 insertions(+), 6 deletions(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 694743f4789f..0e7f64efb129 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -289,12 +289,52 @@ int shape_size(const ShapeExpr& shape) { return ret; } +std::unordered_set gather_candidate_sinfo( + const StructInfo& result_sinfo) { + if (auto* tensor_info = result_sinfo.as()) { + // don't consider void dtype (don't know the size at compile time) + if (tensor_info->dtype.is_void()) { + return {}; + } + // don't consider cases where we don't know the shape at compile time + // TODO(@slyubomirsky): variables might be okay if we use the arithmetic analyzer + if (auto* shape_node = tensor_info->shape.as()) { + for (auto dim : shape_node->values) { + if (!dim.as()) { + return {}; + } + } + return {GetRef(tensor_info)}; + } else { + return {}; + } + } else if (auto* tuple_info = result_sinfo.as()) { + // we can see if the whole tuple matches or go for any of the components + std::unordered_set ret; + for (auto field : tuple_info->fields) { + auto field_candidates = gather_candidate_sinfo(field); + ret.insert(field_candidates.begin(), field_candidates.end()); + } + // at least one field should be eligible to be done in-place + if (!ret.empty()) { + ret.insert(GetRef(tuple_info)); + } + return ret; + } else { + // don't consider any other types + return {}; + } +} + std::pair size_matches(const StructInfo& target_info, const StructInfo& arg_info) { if (target_info.as() && arg_info.as()) { auto target_tensor = Downcast(target_info); auto arg_tensor = Downcast(arg_info); if (target_tensor->shape.defined() && target_tensor->shape.as() && arg_tensor->shape.defined() && arg_tensor->shape.as()) { + if (target_tensor->dtype != arg_tensor->dtype) { + return {false, false}; + } auto target_shape = Downcast(target_tensor->shape); auto arg_shape = Downcast(arg_tensor->shape); int target_size = shape_size(target_shape); @@ -305,6 +345,9 @@ std::pair size_matches(const StructInfo& target_info, const StructIn // exact match: number of dims and each dim matches if (target_shape->values.size() == arg_shape->values.size()) { for (size_t i = 0; i < target_shape->values.size(); i++) { + if (!arg_shape->values[i].as()) { + return {false, false}; + } if (Downcast(target_shape->values[i])->value != Downcast(arg_shape->values[i])->value) { return {true, false}; @@ -497,14 +540,21 @@ std::pair, std::vector> find_inplace_opportunities(const D std::unordered_set candidates; std::unordered_set exact_match_candidates; + auto target_sinfo = gather_candidate_sinfo(GetStructInfo(defined_var)); + // can't be done in-place, ignore + if (target_sinfo.empty()) { + continue; + } + // Check that at least one argument matches size with the result for (auto arg : call_node->args) { - std::pair match = - size_matches(GetStructInfo(defined_var), GetStructInfo(arg)); - if (match.first) { - candidates.insert(arg); - if (match.second) { - exact_match_candidates.insert(arg); + for (auto target : target_sinfo) { + std::pair match = size_matches(target, GetStructInfo(arg)); + if (match.first) { + candidates.insert(arg); + if (match.second) { + exact_match_candidates.insert(arg); + } } } } From 693bc825d49c7dac9a4f10fa07c9ede3423ec187 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 23 Oct 2023 16:55:51 -0400 Subject: [PATCH 15/55] Handle output vars and tuple get item --- src/relax/transform/dataflow_in_place.cc | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 0e7f64efb129..9acc9f0984d1 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -47,6 +47,13 @@ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> anal // (those captured from the outer scope) if (value.as()) { used_vars = FreeVars(value); + } else if (value.as()) { + // Special case: we do not consider a tuple index to be a "use." + // This is a bit of a hack but allows us to do operations that + // create tuples to be done in-place (otherwise, any index of the tuple + // would be considered a use and so the tuple would be live later). + // Hence we keep the array empty. + ; } else { used_vars = AllVars(value); } @@ -58,7 +65,13 @@ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> anal } if (!ret.count(defined_var)) { - ret[defined_var] = {i, block->bindings.size()}; + // if it's an output, then it lives past the end of the block + if (!defined_var.as()) { + ret[defined_var] = {i, block->bindings.size()}; + } else { + // otherwise, it's live only here + ret[defined_var] = {i, i}; + } } else { // this means the var is used later but we encountered its definition now auto last_range = ret[defined_var]; From bcc85137f64282193f13032b1fef4ba6e1f40f83 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 7 Nov 2023 17:43:48 -0500 Subject: [PATCH 16/55] Add legalization for in-place functions --- src/relax/transform/dataflow_in_place.cc | 85 ++++++++++++++++++++++-- 1 file changed, 81 insertions(+), 4 deletions(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 9acc9f0984d1..34af8fbd9c23 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -19,9 +19,11 @@ */ #include +#include #include #include #include +#include #include "utils.h" @@ -584,9 +586,6 @@ std::pair, std::vector> find_inplace_opportunities(const D remove_candidates.insert(candidate); } } - // bizarre bug: this works, - // but candidates.erase(remove_candidates.begin(), remove_candidates.end()) - // gives a segfault for (auto candidate : remove_candidates) { candidates.erase(candidate); } @@ -613,7 +612,6 @@ std::pair, std::vector> find_inplace_opportunities(const D remove.insert(var); } } - // same issue, using the ranged erase causes a segfault for (auto var : remove) { currently_live.erase(var); } @@ -622,6 +620,81 @@ std::pair, std::vector> find_inplace_opportunities(const D return {size_match_list, exact_match_list}; } +Call add_inplace_legalization(const BlockBuilder& bb, const Call& call, + const Array& inplace_indices) { + static const auto& legalize_map = Op::GetAttrMap("FLegalize"); + static const auto& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); + + auto op = Downcast(call->op); + auto legalized_call = Downcast(legalize_map[op](bb, call)); + auto* legalized_call_cow = legalized_call.CopyOnWrite(); + + // The legalized call should be call_tir. We will replace it with call_tir_inplace + // and replace the called PrimFunc with an inplace version + auto legal_op = Downcast(legalized_call->args[0]); + auto inline_legal_op_name = legal_op->name_hint + "_inline"; + auto mod = bb->GetContextIRModule(); + + auto legal_primfunc = Downcast(mod->Lookup(legal_op)); + auto* legal_primfunc_cow = legal_primfunc.CopyOnWrite(); + size_t num_outs = inplace_indices.size(); + size_t num_params = legal_primfunc->params.size(); + Map subst_map; + for (size_t i = 0; i < num_outs; i++) { + // we will substitute output i with the corresponding param indicated by inplace indices + subst_map.Set(legal_primfunc->params[num_params - num_outs + i], + legal_primfunc->params[inplace_indices[i].IntValue()]); + } + // take off the last num_outputs arguments + legal_primfunc_cow->params = Array( + legal_primfunc->params.begin(), legal_primfunc->params.begin() + (num_params - num_outs)); + // apply substitution + legal_primfunc_cow->body = + tir::Substitute(legal_primfunc->body, [&subst_map](const tir::Var& v) -> Optional { + if (subst_map.count(v)) { + return subst_map.at(v); + } else { + return Optional(); + } + }); + + // remove the now-unused outputs from the buffer map if they're there + auto buffer_map = legal_primfunc->buffer_map; + for (size_t i = 0; i < num_outs; i++) { + auto out_var = legal_primfunc->params[num_params - num_outs + i]; + if (buffer_map.count(out_var)) { + buffer_map.erase(out_var); + } + } + legal_primfunc_cow->buffer_map = buffer_map; + + // set the no alias attribute to false + auto legal_primfunc_attrs = legal_primfunc->attrs; + auto* legal_primfunc_attrs_cow = legal_primfunc_attrs.CopyOnWrite(); + auto legal_primfunc_attrs_dict = legal_primfunc_attrs_cow->dict; + legal_primfunc_attrs_dict.erase(tir::attr::kNoAlias); + legal_primfunc_attrs_dict.Set(tir::attr::kNoAlias, Bool(false)); + legal_primfunc_attrs_cow->dict = legal_primfunc_attrs_dict; + legal_primfunc_cow->attrs = legal_primfunc_attrs; + + mod->Remove(legal_op); + bb->AddFunction(legal_primfunc, inline_legal_op_name); + auto new_gv = mod->GetGlobalVar(inline_legal_op_name); + + // update the call (change the op, update the argument, change the attrs) + legalized_call_cow->op = call_tir_inplace_op; + + Array new_args(legalized_call->args.begin(), legalized_call->args.end()); + new_args.Set(0, new_gv); + legalized_call_cow->args = new_args; + + ObjectPtr attrs = make_object(); + attrs->inplace_indices = inplace_indices; + legalized_call_cow->attrs = Attrs(attrs); + + return legalized_call; +} + // export for testing namespace transform { @@ -682,6 +755,10 @@ TVM_REGISTER_GLOBAL("relax.analysis.DataflowAliasAnalysis").set_body_typed(Dataf TVM_REGISTER_GLOBAL("relax.analysis.DataflowInPlaceAnalysis") .set_body_typed(DataflowInPlaceAnalysis); +// really only for testing +TVM_REGISTER_GLOBAL("relax.transform.SingleInplaceCall") + .set_body_typed(relax::add_inplace_legalization); + } // namespace transform } // namespace relax } // namespace tvm \ No newline at end of file From eb7a004a834d786885615416641e7669c752f5fc Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 7 Nov 2023 17:46:22 -0500 Subject: [PATCH 17/55] No need to update the NoAlias attribute, actually --- src/relax/transform/dataflow_in_place.cc | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 34af8fbd9c23..0c0c895b050a 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -668,15 +668,6 @@ Call add_inplace_legalization(const BlockBuilder& bb, const Call& call, } legal_primfunc_cow->buffer_map = buffer_map; - // set the no alias attribute to false - auto legal_primfunc_attrs = legal_primfunc->attrs; - auto* legal_primfunc_attrs_cow = legal_primfunc_attrs.CopyOnWrite(); - auto legal_primfunc_attrs_dict = legal_primfunc_attrs_cow->dict; - legal_primfunc_attrs_dict.erase(tir::attr::kNoAlias); - legal_primfunc_attrs_dict.Set(tir::attr::kNoAlias, Bool(false)); - legal_primfunc_attrs_cow->dict = legal_primfunc_attrs_dict; - legal_primfunc_cow->attrs = legal_primfunc_attrs; - mod->Remove(legal_op); bb->AddFunction(legal_primfunc, inline_legal_op_name); auto new_gv = mod->GetGlobalVar(inline_legal_op_name); From 2e9315d2646fc83fb6da61aa99be68290f87427b Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 7 Nov 2023 21:46:42 -0500 Subject: [PATCH 18/55] Fix TIR transformation, add tests for inline transformation --- python/tvm/relax/analysis/analysis.py | 14 +- src/relax/transform/dataflow_in_place.cc | 138 ++++++++++++++++--- tests/python/relax/test_dataflow_in_place.py | 67 +++++++++ 3 files changed, 195 insertions(+), 24 deletions(-) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 5cae33a0c7a2..f158937d395a 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -27,6 +27,7 @@ import tvm from tvm import tir from tvm import IRModule +from tvm.relax import BlockBuilder from tvm.relax.ty import Type from tvm.relax.struct_info import StructInfo, FuncStructInfo from tvm.relax.expr import DataflowBlock, Var, GlobalVar, Expr, Function, Call, Binding @@ -534,7 +535,7 @@ def detect_recursion(mod: tvm.IRModule) -> List[List[GlobalVar]]: def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int]]: live_ranges = _ffi_api.DataflowLivenessAnalysis(block) # type: ignore ret = {} - for (var, live_range) in live_ranges.items(): + for var, live_range in live_ranges.items(): ret[var] = tuple(live_range) return ret # type: ignore @@ -545,9 +546,9 @@ def dataflow_alias_analysis( alias_sets, tuple_map = _ffi_api.DataflowAliasAnalysis(block, inputs) # type: ignore res_alias_sets = {} res_tuple_map = {} - for (var, alias_set) in alias_sets.items(): + for var, alias_set in alias_sets.items(): res_alias_sets[var] = set(alias_set) - for (idx, elem_alias_sets) in tuple_map.items(): + for idx, elem_alias_sets in tuple_map.items(): res_tuple_map[idx] = [set(alias_set) for alias_set in elem_alias_sets] return res_alias_sets, res_tuple_map # type: ignore @@ -557,3 +558,10 @@ def dataflow_inplace_analysis( ) -> Tuple[List[int], List[int]]: index_lists = _ffi_api.DataflowInPlaceAnalysis(block, inputs) # type: ignore return tuple(map(list, index_lists)) # type: ignore + + +# not actually an analysis but putting it here for testing +def dataflow_single_inplace_call( + builder: BlockBuilder, call: Call, inplace_indices: List[int] +) -> Call: + return _ffi_api.SingleInplaceCall(builder, call, inplace_indices) # type: ignore diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 0c0c895b050a..e088b19a02ca 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -620,6 +620,85 @@ std::pair, std::vector> find_inplace_opportunities(const D return {size_match_list, exact_match_list}; } +tir::Stmt remap_buffers(const tir::Stmt& stmt, const Map& buffer_map) { + class BufferMapper : public tir::StmtExprMutator { + public: + explicit BufferMapper(const Map& buffer_map) + : buffer_map_(buffer_map) {} + + tir::Stmt Remap(const tir::Stmt& stmt) { return VisitStmt(stmt); } + + PrimExpr VisitExpr_(const tir::BufferLoadNode* op) final { + auto node = Downcast(tir::StmtExprMutator::VisitExpr_(op)); + auto* node_cow = node.CopyOnWrite(); + node_cow->buffer = AttemptRemap(node->buffer); + return node; + } + + tir::Stmt VisitStmt_(const tir::BufferStoreNode* op) final { + auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + auto* node_cow = node.CopyOnWrite(); + node_cow->buffer = AttemptRemap(node->buffer); + return node; + } + + tir::Stmt VisitStmt_(const tir::BufferRealizeNode* op) final { + auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + auto* node_cow = node.CopyOnWrite(); + node_cow->buffer = AttemptRemap(node->buffer); + return node; + } + + tir::Stmt VisitStmt_(const tir::DeclBufferNode* op) final { + auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + auto* node_cow = node.CopyOnWrite(); + node_cow->buffer = AttemptRemap(node->buffer); + return node; + } + + tir::Stmt VisitStmt_(const tir::BlockNode* op) final { + auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + auto* node_cow = node.CopyOnWrite(); + // need the lambdas because class methods are not first-class (how ironic) + node_cow->alloc_buffers = + node->alloc_buffers.Map([this](const tir::Buffer& b) { return AttemptRemap(b); }); + node_cow->reads = node->reads.Map( + [this](const tir::BufferRegion& br) { return VisitBufferRegion(br); }); + node_cow->writes = node->writes.Map( + [this](const tir::BufferRegion& br) { return VisitBufferRegion(br); }); + node_cow->match_buffers = node->match_buffers.Map( + [this](const tir::MatchBufferRegion& mbr) { return VisitMatchBufferRegion(mbr); }); + return node; + } + + private: + tir::Buffer AttemptRemap(const tir::Buffer& buffer) { + if (buffer_map_.count(buffer)) { + return buffer_map_.at(buffer); + } + return buffer; + } + + tir::BufferRegion VisitBufferRegion(tir::BufferRegion region) { + auto* region_cow = region.CopyOnWrite(); + region_cow->buffer = AttemptRemap(region_cow->buffer); + return region; + } + + tir::MatchBufferRegion VisitMatchBufferRegion(tir::MatchBufferRegion region) { + auto* region_cow = region.CopyOnWrite(); + region_cow->buffer = AttemptRemap(region_cow->buffer); + return region; + } + + const Map& buffer_map_; + }; + + BufferMapper mapper(buffer_map); + auto ret = mapper.Remap(stmt); + return ret; +} + Call add_inplace_legalization(const BlockBuilder& bb, const Call& call, const Array& inplace_indices) { static const auto& legalize_map = Op::GetAttrMap("FLegalize"); @@ -632,44 +711,61 @@ Call add_inplace_legalization(const BlockBuilder& bb, const Call& call, // The legalized call should be call_tir. We will replace it with call_tir_inplace // and replace the called PrimFunc with an inplace version auto legal_op = Downcast(legalized_call->args[0]); - auto inline_legal_op_name = legal_op->name_hint + "_inline"; + auto inline_legal_op_name = legal_op->name_hint + "_inplace"; auto mod = bb->GetContextIRModule(); auto legal_primfunc = Downcast(mod->Lookup(legal_op)); auto* legal_primfunc_cow = legal_primfunc.CopyOnWrite(); size_t num_outs = inplace_indices.size(); size_t num_params = legal_primfunc->params.size(); - Map subst_map; + + // the replacement we must make: + // 1. For each output var, replace its corresponding buffers with the corresponding inplace index + // var's buffers + // 2. For each output var, replace its instances with the corresponding inplace index var + // 3. Do the same for the *buffer vars* corresponding to the output vars + // 4. Remove the output vars from the param list and buffer map + Map buffer_subst_map; + Map var_subst_map; for (size_t i = 0; i < num_outs; i++) { // we will substitute output i with the corresponding param indicated by inplace indices - subst_map.Set(legal_primfunc->params[num_params - num_outs + i], - legal_primfunc->params[inplace_indices[i].IntValue()]); + auto output_var = legal_primfunc->params[num_params - num_outs + i]; + auto inplace_var = legal_primfunc->params[inplace_indices[i].IntValue()]; + var_subst_map.Set(output_var, inplace_var); + + // also do the same with the buffer vars + auto output_buffer = legal_primfunc->buffer_map.at(output_var); + auto inplace_buffer = legal_primfunc->buffer_map.at(inplace_var); + var_subst_map.Set(output_buffer->data, inplace_buffer->data); + buffer_subst_map.Set(output_buffer, inplace_buffer); } - // take off the last num_outputs arguments - legal_primfunc_cow->params = Array( - legal_primfunc->params.begin(), legal_primfunc->params.begin() + (num_params - num_outs)); - // apply substitution - legal_primfunc_cow->body = - tir::Substitute(legal_primfunc->body, [&subst_map](const tir::Var& v) -> Optional { - if (subst_map.count(v)) { - return subst_map.at(v); - } else { - return Optional(); + + // apply substitutions + legal_primfunc_cow->body = remap_buffers(legal_primfunc->body, buffer_subst_map); + legal_primfunc_cow->body = tir::Substitute( + legal_primfunc->body, [&var_subst_map](const tir::Var& v) -> Optional { + if (var_subst_map.count(v)) { + return var_subst_map.at(v); } + return Optional(); }); - // remove the now-unused outputs from the buffer map if they're there + // remove the now-unused outputs from the buffer map auto buffer_map = legal_primfunc->buffer_map; for (size_t i = 0; i < num_outs; i++) { - auto out_var = legal_primfunc->params[num_params - num_outs + i]; - if (buffer_map.count(out_var)) { - buffer_map.erase(out_var); - } + buffer_map.erase(legal_primfunc->params[num_params - num_outs + i]); } legal_primfunc_cow->buffer_map = buffer_map; + // now get rid of the last num_outputs arguments + // (couldn't do earlier or else it would have thrown off the indexing) + legal_primfunc_cow->params = Array( + legal_primfunc->params.begin(), legal_primfunc->params.begin() + (num_params - num_outs)); + mod->Remove(legal_op); bb->AddFunction(legal_primfunc, inline_legal_op_name); + // need to update the mod to get the new function + mod = std::move(bb->GetContextIRModule()); auto new_gv = mod->GetGlobalVar(inline_legal_op_name); // update the call (change the op, update the argument, change the attrs) @@ -746,8 +842,8 @@ TVM_REGISTER_GLOBAL("relax.analysis.DataflowAliasAnalysis").set_body_typed(Dataf TVM_REGISTER_GLOBAL("relax.analysis.DataflowInPlaceAnalysis") .set_body_typed(DataflowInPlaceAnalysis); -// really only for testing -TVM_REGISTER_GLOBAL("relax.transform.SingleInplaceCall") +// really only for testing (not actually an analysis, will move) +TVM_REGISTER_GLOBAL("relax.analysis.SingleInplaceCall") .set_body_typed(relax::add_inplace_legalization); } // namespace transform diff --git a/tests/python/relax/test_dataflow_in_place.py b/tests/python/relax/test_dataflow_in_place.py index f06138961849..0e64f48cb818 100644 --- a/tests/python/relax/test_dataflow_in_place.py +++ b/tests/python/relax/test_dataflow_in_place.py @@ -17,10 +17,12 @@ import tvm from tvm import testing +from tvm.relax import BlockBuilder from tvm.relax.analysis import ( dataflow_liveness_analysis, dataflow_alias_analysis, dataflow_inplace_analysis, + dataflow_single_inplace_call, ) from tvm.script.parser import ir as I, relax as R, tir as T @@ -337,5 +339,70 @@ def main( assert exact_match == [1, 2, 5] +def test_inplace_call(): + @I.ir_module + class TestModule: + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + z = R.add(x, y) + q = R.nn.silu(z) + return q + + builder = BlockBuilder(mod=TestModule) + add_call = TestModule["main"].body.blocks[0].bindings[0].value + new_add = dataflow_single_inplace_call(builder, add_call, [0]) + + @T.prim_func(private=True) + def expected_add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1] + + new_mod = builder.get() + tvm.ir.assert_structural_equal(new_mod["add_inplace"], expected_add) + assert new_add.op.name == "relax.call_tir_inplace" + assert new_add.args[0].name_hint == "add_inplace" + for i, arg in enumerate(new_add.args[1].fields): + arg == add_call.args[i] + new_add.attrs.inplace_indices == [0] + + @T.prim_func(private=True) + def expected_silu(A: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + compute = T.alloc_buffer((T.int64(2), T.int64(3))) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.sigmoid(A[v_i0, v_i1]) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], compute[v_ax0, v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] * compute[v_ax0, v_ax1] + + silu_call = TestModule["main"].body.blocks[0].bindings[1].value + new_silu = dataflow_single_inplace_call(builder, silu_call, [0]) + + new_mod = builder.get() + tvm.ir.assert_structural_equal(new_mod["silu_inplace"], expected_silu) + assert new_silu.op.name == "relax.call_tir_inplace" + assert new_silu.args[0].name_hint == "silu_inplace" + for i, arg in enumerate(new_silu.args[1].fields): + arg == silu_call.args[i] + new_silu.attrs.inplace_indices == [0] + + if __name__ == "__main__": testing.main() From 5616789a175b21a13a910f0a350461cf646fe5f5 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 7 Nov 2023 22:33:09 -0500 Subject: [PATCH 19/55] Only find candidates from supported ops and list _all_ feasible argument indices --- src/relax/transform/dataflow_in_place.cc | 88 +++++++++++++------- tests/python/relax/test_dataflow_in_place.py | 26 ++++-- 2 files changed, 79 insertions(+), 35 deletions(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index e088b19a02ca..7cc19b70dee5 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -497,21 +497,27 @@ bool df_inplace_conditions_met( } } +// this is obviously not a complete list +static std::unordered_set SUPPORTED_OPS = {"relax.add", "relax.subtract", + "relax.multiply", "relax.divide", + "relax.nn.silu", "relax.nn.relu"}; +bool op_supports_inplace(const Op& op) { return SUPPORTED_OPS.count(op->name); } + // check for in-place eligibility: // 1. see if there's an arg big enough to hold the result // 2. see if the arg is live past the call // 3. see if the arg has an alias that's live past the call // if conditions are met, we're good to go -std::pair, std::vector> find_inplace_opportunities(const DataflowBlock& block, - const Array& inputs) { +std::pair>, std::vector>> find_inplace_opportunities( + const DataflowBlock& block, const Array& inputs) { auto live_ranges = analyze_liveness(block); AliasAnalyzer analyzer; auto alias_info = analyzer.Analyze(block, inputs); auto alias_sets = alias_info.first; auto tuple_map = alias_info.second; - std::vector size_match_list; - std::vector exact_match_list; + std::vector> size_match_list; + std::vector> exact_match_list; // sort the live ranges by starting index std::vector live_order; @@ -551,9 +557,13 @@ std::pair, std::vector> find_inplace_opportunities(const D } if (auto* call_node = value.as()) { - if (call_node->op.as()) { - std::unordered_set candidates; - std::unordered_set exact_match_candidates; + if (auto* op_node = call_node->op.as()) { + if (!op_supports_inplace(GetRef(op_node))) { + continue; + } + + std::unordered_set candidates; + std::unordered_set exact_match_candidates; auto target_sinfo = gather_candidate_sinfo(GetStructInfo(defined_var)); // can't be done in-place, ignore @@ -562,13 +572,14 @@ std::pair, std::vector> find_inplace_opportunities(const D } // Check that at least one argument matches size with the result - for (auto arg : call_node->args) { + for (size_t j = 0; j < call_node->args.size(); j++) { + auto arg = call_node->args[j]; for (auto target : target_sinfo) { std::pair match = size_matches(target, GetStructInfo(arg)); if (match.first) { - candidates.insert(arg); + candidates.insert(static_cast(j)); if (match.second) { - exact_match_candidates.insert(arg); + exact_match_candidates.insert(static_cast(j)); } } } @@ -579,27 +590,38 @@ std::pair, std::vector> find_inplace_opportunities(const D // Make sure at least one candidate is not live past this point and does not have an alias // live past this point - std::unordered_set remove_candidates; + std::unordered_set remove_candidates; for (auto candidate : candidates) { if (!df_inplace_conditions_met(live_ranges, alias_sets, tuple_map, currently_live, - candidate, i)) { + call_node->args[candidate], i)) { remove_candidates.insert(candidate); } } + // (remove now to avoid modifying the list as we iterate on it) for (auto candidate : remove_candidates) { candidates.erase(candidate); } - // if we have a candidate, then this can be made in-place. Report the result - if (candidates.size()) { - size_match_list.push_back(i); + // if we have a candidate, then this can be made in-place. Report the appropriate candidates + if (!candidates.size()) { + continue; } + + // produce a list of candidates for this index + std::vector size_match_indices = {static_cast(i)}; + size_match_indices.insert(size_match_indices.end(), candidates.begin(), candidates.end()); + size_match_list.push_back(size_match_indices); + + // also gather up the exact match candidates if there are any + std::vector exact_match_indices = {static_cast(i)}; for (auto candidate : candidates) { if (exact_match_candidates.count(candidate)) { - exact_match_list.push_back(i); - break; + exact_match_indices.push_back(candidate); } } + if (exact_match_indices.size() > 1) { + exact_match_list.push_back(exact_match_indices); + } } } @@ -662,10 +684,10 @@ tir::Stmt remap_buffers(const tir::Stmt& stmt, const Mapalloc_buffers = node->alloc_buffers.Map([this](const tir::Buffer& b) { return AttemptRemap(b); }); - node_cow->reads = node->reads.Map( - [this](const tir::BufferRegion& br) { return VisitBufferRegion(br); }); - node_cow->writes = node->writes.Map( - [this](const tir::BufferRegion& br) { return VisitBufferRegion(br); }); + node_cow->reads = + node->reads.Map([this](const tir::BufferRegion& br) { return VisitBufferRegion(br); }); + node_cow->writes = + node->writes.Map([this](const tir::BufferRegion& br) { return VisitBufferRegion(br); }); node_cow->match_buffers = node->match_buffers.Map( [this](const tir::MatchBufferRegion& mbr) { return VisitMatchBufferRegion(mbr); }); return node; @@ -822,16 +844,24 @@ Array DataflowAliasAnalysis(const DataflowBlock& block, Array in return {new_alias_sets, new_tuple_map}; } -Array> DataflowInPlaceAnalysis(const DataflowBlock& block, - const Array& inputs) { +Array>> DataflowInPlaceAnalysis(const DataflowBlock& block, + const Array& inputs) { auto index_lists = relax::find_inplace_opportunities(block, inputs); - Array size_match_array; - for (int index : index_lists.first) { - size_match_array.push_back(index); + Array> size_match_array; + for (auto indices : index_lists.first) { + Array index_array; + for (auto index : indices) { + index_array.push_back(Integer(index)); + } + size_match_array.push_back(index_array); } - Array exact_match_array; - for (int index : index_lists.second) { - exact_match_array.push_back(index); + Array> exact_match_array; + for (auto indices : index_lists.second) { + Array index_array; + for (auto index : indices) { + index_array.push_back(Integer(index)); + } + exact_match_array.push_back(index_array); } return {size_match_array, exact_match_array}; } diff --git a/tests/python/relax/test_dataflow_in_place.py b/tests/python/relax/test_dataflow_in_place.py index 0e64f48cb818..07af3301b9d7 100644 --- a/tests/python/relax/test_dataflow_in_place.py +++ b/tests/python/relax/test_dataflow_in_place.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +from typing import List, Set, Tuple import tvm from tvm import testing from tvm.relax import BlockBuilder @@ -327,19 +328,32 @@ def main( p = R.add(z, q) # can be done inplace: neither z nor q is used later r = p # alias of p m = R.multiply(p, p) # p is not used later but r is, so can't do inplace - n = R.add(m, r) # can be done inplace: neither is used again - l = R.reshape(n, (1, 2, 3)) # same size but not not identical shape - ret = R.reshape(l, (2, 3)) # same size but not identical shape + n = R.add(m, r) # can be done inplace: r is not used again + ret = R.subtract(n, m) # can be done inplace: neither is used again R.output(ret) return ret block = InplaceBasic["main"].body.blocks[0] size_match, exact_match = dataflow_inplace_analysis(block, InplaceBasic["main"].params) - assert size_match == [1, 2, 5, 6, 7] - assert exact_match == [1, 2, 5] + # order does not matter for the listing of candidates, so we have to implement as sets + def assert_candidate_list( + actual: List[List[int]], expected: List[Tuple[int, Set[int]]] + ) -> None: + assert len(actual) == len(expected) + for i in range(len(actual)): + assert actual[i][0] == expected[i][0] + assert len(expected[i][1]) == len(actual[i]) - 1 + for j in range(len(expected[i][1])): + assert actual[i][j + 1] in expected[i][1] -def test_inplace_call(): + assert_candidate_list(size_match, [(1, {0, 1}), (2, {0, 1}), (5, {1}), (6, {0, 1})]) + # TODO(@slyubomirsky): I couldn't think of an easy example where sizes don't match, + # but broadcasting might cause it to happen + assert_candidate_list(exact_match, [(1, {0, 1}), (2, {0, 1}), (5, {1}), (6, {0, 1})]) + + +def test_inplace_single_call(): @I.ir_module class TestModule: @R.function From dd3c8828045605ced34c2db27a70b91f36608250 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 7 Nov 2023 23:05:24 -0500 Subject: [PATCH 20/55] Implement basic transformation pass --- src/relax/transform/dataflow_in_place.cc | 89 +++++++++++++++++------- 1 file changed, 65 insertions(+), 24 deletions(-) diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 7cc19b70dee5..cae82ce65222 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -30,20 +30,25 @@ namespace tvm { namespace relax { +Expr BindingValue(const Binding& b) { + Expr value; + if (const auto* var_binding = b.as()) { + value = var_binding->value; + } else if (const auto* match_binding = b.as()) { + value = match_binding->value; + } else { + CHECK(false) << "Invalid binding"; // impossible + } + return value; +} + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> analyze_liveness( const DataflowBlock& block) { std::unordered_map, ObjectPtrHash, ObjectPtrEqual> ret; for (int i = block->bindings.size() - 1; i >= 0; i--) { Binding b = block->bindings[i]; Var defined_var = b->var; - Expr value; - if (const auto* var_binding = b.as()) { - value = var_binding->value; - } else if (const auto* match_binding = b.as()) { - value = match_binding->value; - } else { - CHECK(false) << "Invalid binding"; // impossible - } + Expr value = BindingValue(b); Array used_vars; // for a function literal, we consider only the free vars // (those captured from the outer scope) @@ -104,14 +109,7 @@ class AliasAnalyzer { for (const Binding& binding : block->bindings) { Var current_var = binding->var; - Expr value; - if (const auto* var_binding = binding.as()) { - value = var_binding->value; - } else if (const auto* match_binding = binding.as()) { - value = match_binding->value; - } else { - CHECK(false) << "Invalid binding"; // impossible - } + Expr value = BindingValue(binding); alias_map_[current_var] = get_alias_set(value, current_var); } @@ -547,14 +545,7 @@ std::pair>, std::vector>> find_inp // if we reach a binding check the conditions Binding b = block->bindings[i]; Var defined_var = b->var; - Expr value; - if (const auto* var_binding = b.as()) { - value = var_binding->value; - } else if (const auto* match_binding = b.as()) { - value = match_binding->value; - } else { - CHECK(false) << "Invalid binding"; // impossible - } + Expr value = BindingValue(b); if (auto* call_node = value.as()) { if (auto* op_node = call_node->op.as()) { @@ -876,6 +867,56 @@ TVM_REGISTER_GLOBAL("relax.analysis.DataflowInPlaceAnalysis") TVM_REGISTER_GLOBAL("relax.analysis.SingleInplaceCall") .set_body_typed(relax::add_inplace_legalization); +// not actually an analysis, will rename +TVM_REGISTER_GLOBAL("relax.analysis.DataflowInsertInPlaceCalls").set_body_typed([]() -> Pass { + return CreateDataflowBlockPass( + [](const DataflowBlock& block, const IRModule& mod, const PassContext& ctx) -> DataflowBlock { + BlockBuilder bb = BlockBuilder::Create(mod); + std::unordered_set free_var_set; + for (auto binding : block->bindings) { + auto binding_free_vars = FreeVars(BindingValue(binding)); + free_var_set.insert(binding_free_vars.begin(), binding_free_vars.end()); + } + Array free_var_list(free_var_set.begin(), free_var_set.end()); + + // for now, only handle exact match cases + auto matches_found = find_inplace_opportunities(block, free_var_list); + auto exact_matches = matches_found.second; + + Array new_bindings; + int current_match_index = 0; + for (size_t i = 0; i < block->bindings.size(); i++) { + int candidate_binding_idx = exact_matches[current_match_index][0]; + if (candidate_binding_idx != static_cast(i)) { + new_bindings.push_back(block->bindings[i]); + continue; + } + auto target_binding = block->bindings[i]; + auto target_call = Downcast(BindingValue(target_binding)); + // can just pick the first index arbitrarily (only using one output for now too) + auto new_call = + add_inplace_legalization(bb, target_call, {exact_matches[current_match_index][1]}); + // now replace the binding appropriately + if (auto* var_binding_node = target_binding.as()) { + auto var_binding = GetRef(var_binding_node); + auto* var_binding_cow = var_binding.CopyOnWrite(); + var_binding_cow->value = new_call; + new_bindings.push_back(var_binding); + } else if (auto* match_cast_node = target_binding.as()) { + auto match_cast = GetRef(match_cast_node); + auto* match_cast_cow = match_cast.CopyOnWrite(); + match_cast_cow->value = new_call; + new_bindings.push_back(match_cast); + } else { + CHECK(false) << "Invalid binding type"; + } + current_match_index++; + } + return DataflowBlock(new_bindings, block->span); + }, + 0, "DataflowInsertInPlaceCalls", {}); +}); + } // namespace transform } // namespace relax } // namespace tvm \ No newline at end of file From b8f735942b527e2125b648282c6a7752b48dc4e9 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 8 Nov 2023 16:45:25 -0500 Subject: [PATCH 21/55] Use a module pass so wider changes are visible, reorganize --- python/tvm/relax/analysis/analysis.py | 12 +- src/relax/transform/dataflow_in_place.cc | 296 +++++++++++-------- tests/python/relax/test_dataflow_in_place.py | 29 +- 3 files changed, 211 insertions(+), 126 deletions(-) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index f158937d395a..1f32b0e8b600 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -562,6 +562,12 @@ def dataflow_inplace_analysis( # not actually an analysis but putting it here for testing def dataflow_single_inplace_call( - builder: BlockBuilder, call: Call, inplace_indices: List[int] -) -> Call: - return _ffi_api.SingleInplaceCall(builder, call, inplace_indices) # type: ignore + mod: IRModule, call: Call, inplace_indices: List[int] +) -> Tuple[Call, IRModule]: + ret = _ffi_api.SingleInplaceCall(mod, call, inplace_indices) # type: ignore + return (ret[0], ret[1]) # type: ignore + + +# also not actually an analysis but putting it here for testing +def dataflow_insert_inplace_calls() -> tvm.ir.transform.Pass: + return _ffi_api.DataflowInsertInPlaceCalls() # type: ignore diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index cae82ce65222..0dda066eab23 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -18,6 +18,7 @@ * under the License. */ +#include #include #include #include @@ -712,88 +713,182 @@ tir::Stmt remap_buffers(const tir::Stmt& stmt, const Map& inplace_indices) { - static const auto& legalize_map = Op::GetAttrMap("FLegalize"); - static const auto& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); - - auto op = Downcast(call->op); - auto legalized_call = Downcast(legalize_map[op](bb, call)); - auto* legalized_call_cow = legalized_call.CopyOnWrite(); - - // The legalized call should be call_tir. We will replace it with call_tir_inplace - // and replace the called PrimFunc with an inplace version - auto legal_op = Downcast(legalized_call->args[0]); - auto inline_legal_op_name = legal_op->name_hint + "_inplace"; - auto mod = bb->GetContextIRModule(); - - auto legal_primfunc = Downcast(mod->Lookup(legal_op)); - auto* legal_primfunc_cow = legal_primfunc.CopyOnWrite(); - size_t num_outs = inplace_indices.size(); - size_t num_params = legal_primfunc->params.size(); - - // the replacement we must make: - // 1. For each output var, replace its corresponding buffers with the corresponding inplace index - // var's buffers - // 2. For each output var, replace its instances with the corresponding inplace index var - // 3. Do the same for the *buffer vars* corresponding to the output vars - // 4. Remove the output vars from the param list and buffer map - Map buffer_subst_map; - Map var_subst_map; - for (size_t i = 0; i < num_outs; i++) { - // we will substitute output i with the corresponding param indicated by inplace indices - auto output_var = legal_primfunc->params[num_params - num_outs + i]; - auto inplace_var = legal_primfunc->params[inplace_indices[i].IntValue()]; - var_subst_map.Set(output_var, inplace_var); - - // also do the same with the buffer vars - auto output_buffer = legal_primfunc->buffer_map.at(output_var); - auto inplace_buffer = legal_primfunc->buffer_map.at(inplace_var); - var_subst_map.Set(output_buffer->data, inplace_buffer->data); - buffer_subst_map.Set(output_buffer, inplace_buffer); +class ModuleInplaceTransformer : public ExprMutator { + public: + explicit ModuleInplaceTransformer(const IRModule& mod) : mod_(mod) { + builder_ = BlockBuilder::Create(mod); } - // apply substitutions - legal_primfunc_cow->body = remap_buffers(legal_primfunc->body, buffer_subst_map); - legal_primfunc_cow->body = tir::Substitute( - legal_primfunc->body, [&var_subst_map](const tir::Var& v) -> Optional { - if (var_subst_map.count(v)) { - return var_subst_map.at(v); - } - return Optional(); - }); + IRModule Transform() { + // visit every Relax function in the module + for (auto kv : mod_->functions) { + if (auto* func_node = kv.second.as()) { + auto gv = kv.first; + auto func_params = func_node->params; + auto function = GetRef(func_node); + auto* function_cow = function.CopyOnWrite(); + auto new_body = VisitExpr(function->body); + function_cow->body = new_body; + builder_->UpdateFunction(gv, function); + } + } - // remove the now-unused outputs from the buffer map - auto buffer_map = legal_primfunc->buffer_map; - for (size_t i = 0; i < num_outs; i++) { - buffer_map.erase(legal_primfunc->params[num_params - num_outs + i]); + auto ret = builder_->GetContextIRModule(); + // clean up to avoid polluting the IRModule + for (auto gv : legalizers_added) { + ret->Remove(gv); + } + return ret; } - legal_primfunc_cow->buffer_map = buffer_map; - // now get rid of the last num_outputs arguments - // (couldn't do earlier or else it would have thrown off the indexing) - legal_primfunc_cow->params = Array( - legal_primfunc->params.begin(), legal_primfunc->params.begin() + (num_params - num_outs)); + // for handling inner functions + Expr VisitExpr_(const FunctionNode* op) override { + auto old_func_params = func_params; + func_params = op->params; + auto ret = ExprMutator::VisitExpr(GetRef(op)); + func_params = old_func_params; + return ret; + } - mod->Remove(legal_op); - bb->AddFunction(legal_primfunc, inline_legal_op_name); - // need to update the mod to get the new function - mod = std::move(bb->GetContextIRModule()); - auto new_gv = mod->GetGlobalVar(inline_legal_op_name); + // the only case we will override: we will visit all binding blocks + // and replace any valid calls in them + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + auto block = GetRef(op); + std::unordered_set free_var_set; + for (auto binding : block->bindings) { + auto binding_free_vars = FreeVars(BindingValue(binding)); + free_var_set.insert(binding_free_vars.begin(), binding_free_vars.end()); + } + Array free_var_list(free_var_set.begin(), free_var_set.end()); + + // for now, only handle exact match cases + auto matches_found = find_inplace_opportunities(block, free_var_list); + auto exact_matches = matches_found.second; + + Array new_bindings; + int current_match_index = 0; + for (size_t i = 0; i < block->bindings.size(); i++) { + int candidate_binding_idx = exact_matches[current_match_index][0]; + if (candidate_binding_idx != static_cast(i)) { + new_bindings.push_back(block->bindings[i]); + continue; + } + auto target_binding = block->bindings[i]; + auto target_call = Downcast(BindingValue(target_binding)); + // can just pick the first index arbitrarily (only using one output for now too) + auto new_call = CreateInplaceCall(target_call, {exact_matches[current_match_index][1]}); + // now replace the binding appropriately + if (auto* var_binding_node = target_binding.as()) { + auto var_binding = GetRef(var_binding_node); + auto* var_binding_cow = var_binding.CopyOnWrite(); + var_binding_cow->value = new_call; + new_bindings.push_back(var_binding); + } else if (auto* match_cast_node = target_binding.as()) { + auto match_cast = GetRef(match_cast_node); + auto* match_cast_cow = match_cast.CopyOnWrite(); + match_cast_cow->value = new_call; + new_bindings.push_back(match_cast); + } else { + CHECK(false) << "Invalid binding type"; + } + current_match_index++; + } + return DataflowBlock(new_bindings, block->span); + } - // update the call (change the op, update the argument, change the attrs) - legalized_call_cow->op = call_tir_inplace_op; + // exposed for testing + Call CreateInplaceCall(const Call& call, const Array& inplace_indices) { + static const auto& legalize_map = Op::GetAttrMap("FLegalize"); + static const auto& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); + + auto op = Downcast(call->op); + auto legalized_call = Downcast(legalize_map[op](builder_, call)); + auto* legalized_call_cow = legalized_call.CopyOnWrite(); + + // The legalized call should be call_tir. We will replace it with call_tir_inplace + // and replace the called PrimFunc with an inplace version + auto legal_op = Downcast(legalized_call->args[0]); + legalizers_added.push_back(legal_op); + auto inline_legal_op_name = legal_op->name_hint + "_inplace"; + + auto mod = builder_->GetContextIRModule(); + auto legal_primfunc = Downcast(mod->Lookup(legal_op)); + auto* legal_primfunc_cow = legal_primfunc.CopyOnWrite(); + size_t num_outs = inplace_indices.size(); + size_t num_params = legal_primfunc->params.size(); + + // the replacement we must make: + // 1. For each output var, replace its corresponding buffers with the corresponding inplace + // index + // var's buffers + // 2. For each output var, replace its instances with the corresponding inplace index var + // 3. Do the same for the *buffer vars* corresponding to the output vars + // 4. Remove the output vars from the param list and buffer map + Map buffer_subst_map; + Map var_subst_map; + for (size_t i = 0; i < num_outs; i++) { + // we will substitute output i with the corresponding param indicated by inplace indices + auto output_var = legal_primfunc->params[num_params - num_outs + i]; + auto inplace_var = legal_primfunc->params[inplace_indices[i].IntValue()]; + var_subst_map.Set(output_var, inplace_var); + + // also do the same with the buffer vars + auto output_buffer = legal_primfunc->buffer_map.at(output_var); + auto inplace_buffer = legal_primfunc->buffer_map.at(inplace_var); + var_subst_map.Set(output_buffer->data, inplace_buffer->data); + buffer_subst_map.Set(output_buffer, inplace_buffer); + } + + // apply substitutions + legal_primfunc_cow->body = remap_buffers(legal_primfunc->body, buffer_subst_map); + legal_primfunc_cow->body = tir::Substitute( + legal_primfunc->body, [&var_subst_map](const tir::Var& v) -> Optional { + if (var_subst_map.count(v)) { + return var_subst_map.at(v); + } + return Optional(); + }); - Array new_args(legalized_call->args.begin(), legalized_call->args.end()); - new_args.Set(0, new_gv); - legalized_call_cow->args = new_args; + // remove the now-unused outputs from the buffer map + auto buffer_map = legal_primfunc->buffer_map; + for (size_t i = 0; i < num_outs; i++) { + buffer_map.erase(legal_primfunc->params[num_params - num_outs + i]); + } + legal_primfunc_cow->buffer_map = buffer_map; - ObjectPtr attrs = make_object(); - attrs->inplace_indices = inplace_indices; - legalized_call_cow->attrs = Attrs(attrs); + // now get rid of the last num_outputs arguments + // (couldn't do earlier or else it would have thrown off the indexing) + legal_primfunc_cow->params = Array( + legal_primfunc->params.begin(), legal_primfunc->params.begin() + (num_params - num_outs)); - return legalized_call; -} + // note: this might be a good time to get rid of the old legalized function, but we don't do it + // now because later ops might need the same one. Instead, we will clean up at the end + auto new_gv = builder_->AddFunction(legal_primfunc, inline_legal_op_name); + + // update the call (change the op, update the argument, change the attrs) + legalized_call_cow->op = call_tir_inplace_op; + + Array new_args(legalized_call->args.begin(), legalized_call->args.end()); + new_args.Set(0, new_gv); + legalized_call_cow->args = new_args; + + ObjectPtr attrs = make_object(); + attrs->inplace_indices = inplace_indices; + legalized_call_cow->attrs = Attrs(attrs); + + return legalized_call; + } + + // exposed for testing + IRModule CurrentMod() { return builder_->GetContextIRModule(); } + + private: + const IRModule& mod_; + Array + legalizers_added; // Keep track of legalizers we add so we can clean up at the end. + Array func_params; // The current function's params will be treated as non-aliased + // (we are assuming good behavior on the user's part). +}; // export for testing namespace transform { @@ -865,56 +960,21 @@ TVM_REGISTER_GLOBAL("relax.analysis.DataflowInPlaceAnalysis") // really only for testing (not actually an analysis, will move) TVM_REGISTER_GLOBAL("relax.analysis.SingleInplaceCall") - .set_body_typed(relax::add_inplace_legalization); + .set_body_typed([](const IRModule& mod, const Call& call, + const Array& inplace_indices) -> Array { + ModuleInplaceTransformer transformer(mod); + auto ret_call = transformer.CreateInplaceCall(call, inplace_indices); + return Array{ret_call, transformer.CurrentMod()}; + }); // not actually an analysis, will rename TVM_REGISTER_GLOBAL("relax.analysis.DataflowInsertInPlaceCalls").set_body_typed([]() -> Pass { - return CreateDataflowBlockPass( - [](const DataflowBlock& block, const IRModule& mod, const PassContext& ctx) -> DataflowBlock { - BlockBuilder bb = BlockBuilder::Create(mod); - std::unordered_set free_var_set; - for (auto binding : block->bindings) { - auto binding_free_vars = FreeVars(BindingValue(binding)); - free_var_set.insert(binding_free_vars.begin(), binding_free_vars.end()); - } - Array free_var_list(free_var_set.begin(), free_var_set.end()); - - // for now, only handle exact match cases - auto matches_found = find_inplace_opportunities(block, free_var_list); - auto exact_matches = matches_found.second; - - Array new_bindings; - int current_match_index = 0; - for (size_t i = 0; i < block->bindings.size(); i++) { - int candidate_binding_idx = exact_matches[current_match_index][0]; - if (candidate_binding_idx != static_cast(i)) { - new_bindings.push_back(block->bindings[i]); - continue; - } - auto target_binding = block->bindings[i]; - auto target_call = Downcast(BindingValue(target_binding)); - // can just pick the first index arbitrarily (only using one output for now too) - auto new_call = - add_inplace_legalization(bb, target_call, {exact_matches[current_match_index][1]}); - // now replace the binding appropriately - if (auto* var_binding_node = target_binding.as()) { - auto var_binding = GetRef(var_binding_node); - auto* var_binding_cow = var_binding.CopyOnWrite(); - var_binding_cow->value = new_call; - new_bindings.push_back(var_binding); - } else if (auto* match_cast_node = target_binding.as()) { - auto match_cast = GetRef(match_cast_node); - auto* match_cast_cow = match_cast.CopyOnWrite(); - match_cast_cow->value = new_call; - new_bindings.push_back(match_cast); - } else { - CHECK(false) << "Invalid binding type"; - } - current_match_index++; - } - return DataflowBlock(new_bindings, block->span); + return tvm::transform::CreateModulePass( + [](const IRModule& mod, const PassContext& ctx) -> IRModule { + ModuleInplaceTransformer transformer(mod); + return transformer.Transform(); }, - 0, "DataflowInsertInPlaceCalls", {}); + 0, "DataflowInsertInPlaceCalls", {}, false); }); } // namespace transform diff --git a/tests/python/relax/test_dataflow_in_place.py b/tests/python/relax/test_dataflow_in_place.py index 07af3301b9d7..fab038dd24a2 100644 --- a/tests/python/relax/test_dataflow_in_place.py +++ b/tests/python/relax/test_dataflow_in_place.py @@ -24,6 +24,7 @@ dataflow_alias_analysis, dataflow_inplace_analysis, dataflow_single_inplace_call, + dataflow_insert_inplace_calls, ) from tvm.script.parser import ir as I, relax as R, tir as T @@ -364,9 +365,8 @@ def main( q = R.nn.silu(z) return q - builder = BlockBuilder(mod=TestModule) add_call = TestModule["main"].body.blocks[0].bindings[0].value - new_add = dataflow_single_inplace_call(builder, add_call, [0]) + new_add, new_mod = dataflow_single_inplace_call(TestModule, add_call, [0]) @T.prim_func(private=True) def expected_add( @@ -381,7 +381,6 @@ def expected_add( T.writes(A[v_ax0, v_ax1]) A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1] - new_mod = builder.get() tvm.ir.assert_structural_equal(new_mod["add_inplace"], expected_add) assert new_add.op.name == "relax.call_tir_inplace" assert new_add.args[0].name_hint == "add_inplace" @@ -407,9 +406,8 @@ def expected_silu(A: T.Buffer((T.int64(2), T.int64(3)), "float32")): A[v_ax0, v_ax1] = A[v_ax0, v_ax1] * compute[v_ax0, v_ax1] silu_call = TestModule["main"].body.blocks[0].bindings[1].value - new_silu = dataflow_single_inplace_call(builder, silu_call, [0]) + new_silu, new_mod = dataflow_single_inplace_call(TestModule, silu_call, [0]) - new_mod = builder.get() tvm.ir.assert_structural_equal(new_mod["silu_inplace"], expected_silu) assert new_silu.op.name == "relax.call_tir_inplace" assert new_silu.args[0].name_hint == "silu_inplace" @@ -418,5 +416,26 @@ def expected_silu(A: T.Buffer((T.int64(2), T.int64(3)), "float32")): new_silu.attrs.inplace_indices == [0] +def test_insert_inplace_calls(): + @I.ir_module + class EndToEndTest: + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + with R.dataflow(): + z = R.add(x, y) # broadcast happens here + q = R.multiply(z, y) # broadcast again + r = R.subtract(y, y) # now can be done inplace + m = R.multiply(q, r) # should give us all zeros + R.output(m) + return m + + transform_pass = dataflow_insert_inplace_calls() + new_mod = transform_pass(EndToEndTest) + print(new_mod) + assert False + + if __name__ == "__main__": testing.main() From d8a2f4db6b6756b17a08efb7cc29bf2a6723fe37 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 9 Nov 2023 20:40:17 -0500 Subject: [PATCH 22/55] Have an end-to-end test case for the in-place transformation --- tests/python/relax/test_dataflow_in_place.py | 29 +++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_dataflow_in_place.py b/tests/python/relax/test_dataflow_in_place.py index fab038dd24a2..23de5675b667 100644 --- a/tests/python/relax/test_dataflow_in_place.py +++ b/tests/python/relax/test_dataflow_in_place.py @@ -17,8 +17,7 @@ from typing import List, Set, Tuple import tvm -from tvm import testing -from tvm.relax import BlockBuilder +from tvm import relax, testing from tvm.relax.analysis import ( dataflow_liveness_analysis, dataflow_alias_analysis, @@ -28,6 +27,8 @@ ) from tvm.script.parser import ir as I, relax as R, tir as T +import numpy as np + def test_liveness_analysis(): @I.ir_module @@ -433,8 +434,28 @@ def main( transform_pass = dataflow_insert_inplace_calls() new_mod = transform_pass(EndToEndTest) - print(new_mod) - assert False + + # check that all operations are done in-place + assert new_mod["add_inplace"] + assert new_mod["subtract_inplace"] + assert new_mod["multiply_inplace"] + expected_ops = ["add_inplace", "multiply_inplace", "subtract_inplace", "multiply_inplace"] + for i, binding in enumerate(new_mod["main"].body.blocks[0].bindings): + assert binding.value.op.name == "relax.call_tir_inplace" + assert binding.value.args[0].name_hint == expected_ops[i] + + x = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + y = tvm.nd.array(np.random.rand(1, 3).astype("float32")) + expected = np.zeros((2, 3), dtype="float32") + + target = tvm.target.Target("llvm") + ex = relax.build(new_mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"](x, y) + # due to reuse of buffers, the result is actually reference equal to argument x + # (we can disable this by setting the arguments to "unknown value" in the alias analysis) + assert res == x + assert (expected == res.numpy()).all() if __name__ == "__main__": From 87e2a41c64e834b4e059b0e0f2edd8e31a1ae91f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 13 Nov 2023 17:40:58 -0500 Subject: [PATCH 23/55] Rebase fixes and use GetBoundValue instead of reimplementing it --- include/tvm/relax/transform.h | 3 +++ python/tvm/relax/analysis/__init__.py | 5 +++++ src/relax/transform/dataflow_in_place.cc | 23 ++++++----------------- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 5376d99ee15b..af794028102b 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -572,6 +572,9 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2); */ TVM_DLL Pass DeadCodeElimination(Array entry_functions); + + + /*! * \brief Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 * only, and will automatically cast fp32 to fp16 for certain ops. diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py index d8454a02cc84..2702c33839af 100644 --- a/python/tvm/relax/analysis/__init__.py +++ b/python/tvm/relax/analysis/__init__.py @@ -22,6 +22,11 @@ all_vars, bound_vars, contains_impure_call, + dataflow_liveness_analysis, + dataflow_alias_analysis, + dataflow_inplace_analysis, + dataflow_single_inplace_call, + dataflow_insert_inplace_calls, definable_tir_vars_in_struct_info, defined_symbolic_vars, derive_call_ret_struct_info, diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_in_place.cc index 0dda066eab23..ea29d551e74d 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_in_place.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include "utils.h" @@ -31,25 +32,13 @@ namespace tvm { namespace relax { -Expr BindingValue(const Binding& b) { - Expr value; - if (const auto* var_binding = b.as()) { - value = var_binding->value; - } else if (const auto* match_binding = b.as()) { - value = match_binding->value; - } else { - CHECK(false) << "Invalid binding"; // impossible - } - return value; -} - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> analyze_liveness( const DataflowBlock& block) { std::unordered_map, ObjectPtrHash, ObjectPtrEqual> ret; for (int i = block->bindings.size() - 1; i >= 0; i--) { Binding b = block->bindings[i]; Var defined_var = b->var; - Expr value = BindingValue(b); + Expr value = GetBoundValue(b); Array used_vars; // for a function literal, we consider only the free vars // (those captured from the outer scope) @@ -110,7 +99,7 @@ class AliasAnalyzer { for (const Binding& binding : block->bindings) { Var current_var = binding->var; - Expr value = BindingValue(binding); + Expr value = GetBoundValue(binding); alias_map_[current_var] = get_alias_set(value, current_var); } @@ -546,7 +535,7 @@ std::pair>, std::vector>> find_inp // if we reach a binding check the conditions Binding b = block->bindings[i]; Var defined_var = b->var; - Expr value = BindingValue(b); + Expr value = GetBoundValue(b); if (auto* call_node = value.as()) { if (auto* op_node = call_node->op.as()) { @@ -756,7 +745,7 @@ class ModuleInplaceTransformer : public ExprMutator { auto block = GetRef(op); std::unordered_set free_var_set; for (auto binding : block->bindings) { - auto binding_free_vars = FreeVars(BindingValue(binding)); + auto binding_free_vars = FreeVars(GetBoundValue(binding)); free_var_set.insert(binding_free_vars.begin(), binding_free_vars.end()); } Array free_var_list(free_var_set.begin(), free_var_set.end()); @@ -774,7 +763,7 @@ class ModuleInplaceTransformer : public ExprMutator { continue; } auto target_binding = block->bindings[i]; - auto target_call = Downcast(BindingValue(target_binding)); + auto target_call = Downcast(GetBoundValue(target_binding)); // can just pick the first index arbitrarily (only using one output for now too) auto new_call = CreateInplaceCall(target_call, {exact_matches[current_match_index][1]}); // now replace the binding appropriately From 6e7d447478276564d29d776c1d58d19133998973 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 13 Nov 2023 17:46:08 -0500 Subject: [PATCH 24/55] Let's just use 'inplace' everywhere --- python/tvm/relax/analysis/analysis.py | 4 ++-- .../{dataflow_in_place.cc => dataflow_inplace.cc} | 8 ++++---- ...test_dataflow_in_place.py => test_dataflow_inplace.py} | 0 3 files changed, 6 insertions(+), 6 deletions(-) rename src/relax/transform/{dataflow_in_place.cc => dataflow_inplace.cc} (99%) rename tests/python/relax/{test_dataflow_in_place.py => test_dataflow_inplace.py} (100%) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 1f32b0e8b600..240a153b0e4a 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -556,7 +556,7 @@ def dataflow_alias_analysis( def dataflow_inplace_analysis( block: DataflowBlock, inputs: List[Var] ) -> Tuple[List[int], List[int]]: - index_lists = _ffi_api.DataflowInPlaceAnalysis(block, inputs) # type: ignore + index_lists = _ffi_api.DataflowInplaceAnalysis(block, inputs) # type: ignore return tuple(map(list, index_lists)) # type: ignore @@ -570,4 +570,4 @@ def dataflow_single_inplace_call( # also not actually an analysis but putting it here for testing def dataflow_insert_inplace_calls() -> tvm.ir.transform.Pass: - return _ffi_api.DataflowInsertInPlaceCalls() # type: ignore + return _ffi_api.DataflowInsertInplaceCalls() # type: ignore diff --git a/src/relax/transform/dataflow_in_place.cc b/src/relax/transform/dataflow_inplace.cc similarity index 99% rename from src/relax/transform/dataflow_in_place.cc rename to src/relax/transform/dataflow_inplace.cc index ea29d551e74d..d503854823ff 100644 --- a/src/relax/transform/dataflow_in_place.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -919,7 +919,7 @@ Array DataflowAliasAnalysis(const DataflowBlock& block, Array in return {new_alias_sets, new_tuple_map}; } -Array>> DataflowInPlaceAnalysis(const DataflowBlock& block, +Array>> DataflowInplaceAnalysis(const DataflowBlock& block, const Array& inputs) { auto index_lists = relax::find_inplace_opportunities(block, inputs); Array> size_match_array; @@ -944,8 +944,8 @@ Array>> DataflowInPlaceAnalysis(const DataflowBlock& block, TVM_REGISTER_GLOBAL("relax.analysis.DataflowLivenessAnalysis") .set_body_typed(DataflowLivenessAnalysis); TVM_REGISTER_GLOBAL("relax.analysis.DataflowAliasAnalysis").set_body_typed(DataflowAliasAnalysis); -TVM_REGISTER_GLOBAL("relax.analysis.DataflowInPlaceAnalysis") - .set_body_typed(DataflowInPlaceAnalysis); +TVM_REGISTER_GLOBAL("relax.analysis.DataflowInplaceAnalysis") + .set_body_typed(DataflowInplaceAnalysis); // really only for testing (not actually an analysis, will move) TVM_REGISTER_GLOBAL("relax.analysis.SingleInplaceCall") @@ -957,7 +957,7 @@ TVM_REGISTER_GLOBAL("relax.analysis.SingleInplaceCall") }); // not actually an analysis, will rename -TVM_REGISTER_GLOBAL("relax.analysis.DataflowInsertInPlaceCalls").set_body_typed([]() -> Pass { +TVM_REGISTER_GLOBAL("relax.analysis.DataflowInsertInplaceCalls").set_body_typed([]() -> Pass { return tvm::transform::CreateModulePass( [](const IRModule& mod, const PassContext& ctx) -> IRModule { ModuleInplaceTransformer transformer(mod); diff --git a/tests/python/relax/test_dataflow_in_place.py b/tests/python/relax/test_dataflow_inplace.py similarity index 100% rename from tests/python/relax/test_dataflow_in_place.py rename to tests/python/relax/test_dataflow_inplace.py From 03fb2a51b772d53ad7499660ba51e7e36ae5d59c Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 13 Nov 2023 22:57:05 -0500 Subject: [PATCH 25/55] Reorganize code and add more documentation --- include/tvm/relax/analysis.h | 4 - include/tvm/relax/transform.h | 10 +- python/tvm/relax/analysis/__init__.py | 5 - python/tvm/relax/analysis/analysis.py | 42 ----- python/tvm/relax/testing/transform.py | 51 +++++- python/tvm/relax/transform/__init__.py | 1 + python/tvm/relax/transform/transform.py | 15 ++ src/relax/transform/dataflow_inplace.cc | 183 ++++++++++++-------- tests/python/relax/test_dataflow_inplace.py | 6 +- 9 files changed, 188 insertions(+), 129 deletions(-) diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 7ca021fc1558..6e2209d51950 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -533,10 +533,6 @@ TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true); TVM_DLL Map> SuggestLayoutTransforms( const Function& fn, Array write_buffer_transformations); -// included for testing purposes -TVM_DLL Map> DataflowLivenessAnalysis(const DataflowBlock& block); -TVM_DLL Map> DataflowAliasAnalysis(const DataflowBlock& block, - Array inputs); } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index af794028102b..42d1d18d8a32 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -572,8 +572,14 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2); */ TVM_DLL Pass DeadCodeElimination(Array entry_functions); - - +/*! + * \brief Pass that changes calls to supported operators in dataflow blocks into in-place + * implementations. Supported operators will be replaced by calls to `call_tir_inplace` that invoke + * in-place PrimFunc implementations of those operators (which are based on the legalizations of + * those operators). + * \return The pass. + */ +TVM_DLL Pass DataflowUseInplaceCalls(); /*! * \brief Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py index 2702c33839af..d8454a02cc84 100644 --- a/python/tvm/relax/analysis/__init__.py +++ b/python/tvm/relax/analysis/__init__.py @@ -22,11 +22,6 @@ all_vars, bound_vars, contains_impure_call, - dataflow_liveness_analysis, - dataflow_alias_analysis, - dataflow_inplace_analysis, - dataflow_single_inplace_call, - dataflow_insert_inplace_calls, definable_tir_vars_in_struct_info, defined_symbolic_vars, derive_call_ret_struct_info, diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 240a153b0e4a..8e68b12d45ae 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -529,45 +529,3 @@ def detect_recursion(mod: tvm.IRModule) -> List[List[GlobalVar]]: with any other, it will be a singleton in this list. """ return _ffi_api.detect_recursion(mod) # type: ignore - - -# expose for testing -def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int]]: - live_ranges = _ffi_api.DataflowLivenessAnalysis(block) # type: ignore - ret = {} - for var, live_range in live_ranges.items(): - ret[var] = tuple(live_range) - return ret # type: ignore - - -def dataflow_alias_analysis( - block: DataflowBlock, inputs: List[Var] -) -> Tuple[Dict[Var, Set[int]], Dict[int, List[Set[int]]]]: - alias_sets, tuple_map = _ffi_api.DataflowAliasAnalysis(block, inputs) # type: ignore - res_alias_sets = {} - res_tuple_map = {} - for var, alias_set in alias_sets.items(): - res_alias_sets[var] = set(alias_set) - for idx, elem_alias_sets in tuple_map.items(): - res_tuple_map[idx] = [set(alias_set) for alias_set in elem_alias_sets] - return res_alias_sets, res_tuple_map # type: ignore - - -def dataflow_inplace_analysis( - block: DataflowBlock, inputs: List[Var] -) -> Tuple[List[int], List[int]]: - index_lists = _ffi_api.DataflowInplaceAnalysis(block, inputs) # type: ignore - return tuple(map(list, index_lists)) # type: ignore - - -# not actually an analysis but putting it here for testing -def dataflow_single_inplace_call( - mod: IRModule, call: Call, inplace_indices: List[int] -) -> Tuple[Call, IRModule]: - ret = _ffi_api.SingleInplaceCall(mod, call, inplace_indices) # type: ignore - return (ret[0], ret[1]) # type: ignore - - -# also not actually an analysis but putting it here for testing -def dataflow_insert_inplace_calls() -> tvm.ir.transform.Pass: - return _ffi_api.DataflowInsertInplaceCalls() # type: ignore diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index ccae38a138a3..aeb9362729d1 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -17,13 +17,14 @@ # pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ """Relax transformation passes for testing""" +from typing import Dict, List, Set, Tuple import tvm from tvm import ir, relax from tvm.ir import transform from tvm.ir.module import IRModule from tvm.ir.transform import PassContext from tvm.relax import PyExprMutator -from tvm.relax.expr import Call +from tvm.relax.expr import Call, DataflowBlock, Var from tvm.relay.backend.te_compiler import select_implementation from tvm.target import Target @@ -128,3 +129,51 @@ def transform(self): def ApplyEmptyCppMutator() -> tvm.ir.transform.Pass: packed_func = tvm.get_global_func("relax.testing.transform.ApplyEmptyCppMutator") return packed_func() + + +# inner functions for the dataflow inplace transformation exposed for testing +def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int]]: + live_ranges = tvm.get_global_func("relax.testing.transform.DataflowLivenessAnalysis")( + block + ) # type: ignore + ret = {} + for var, live_range in live_ranges.items(): + ret[var] = tuple(live_range) + return ret # type: ignore + + +def dataflow_alias_analysis( + block: DataflowBlock, inputs: List[Var] +) -> Tuple[Dict[Var, Set[int]], Dict[int, List[Set[int]]]]: + alias_sets, tuple_map = tvm.get_global_func("relax.testing.transform.DataflowAliasAnalysis")( + block, + inputs, + ) # type: ignore + res_alias_sets = {} + res_tuple_map = {} + for var, alias_set in alias_sets.items(): + res_alias_sets[var] = set(alias_set) + for idx, elem_alias_sets in tuple_map.items(): + res_tuple_map[idx] = [set(alias_set) for alias_set in elem_alias_sets] + return res_alias_sets, res_tuple_map # type: ignore + + +def dataflow_inplace_analysis( + block: DataflowBlock, inputs: List[Var] +) -> Tuple[List[int], List[int]]: + index_lists = tvm.get_global_func("relax.testing.transform.DataflowInplaceAnalysis")( + block, + inputs, + ) # type: ignore + return tuple(map(list, index_lists)) # type: ignore + + +def dataflow_single_inplace_call( + mod: IRModule, call: Call, inplace_indices: List[int] +) -> Tuple[Call, IRModule]: + ret = tvm.get_global_func("relax.testing.transform.SingleInplaceCall")( + mod, + call, + inplace_indices, + ) # type: ignore + return (ret[0], ret[1]) # type: ignore diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 19316c76b83d..353ee88b6898 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -30,6 +30,7 @@ ConvertLayout, ConvertToDataflow, DataflowBlockPass, + DataflowUseInplaceCalls, DeadCodeElimination, DecomposeOpsForInference, DecomposeOpsForTraining, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 9589f661d79e..379a22b487f0 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -252,6 +252,21 @@ def RemovePurityChecking() -> tvm.ir.transform.Pass: return _ffi_api.RemovePurityChecking() # type: ignore +def DataflowUseInplaceCalls() -> tvm.ir.transform.Pass: + """ + Pass that changes calls to supported operators in dataflow blocks into in-place + implementations. Supported operators will be replaced by calls to `call_tir_inplace` that invoke + in-place PrimFunc implementations of those operators (which are based on the legalizations of + those operators). + + Returns + ------- + ret: tvm.ir.transform.Pass + The pass + """ + return _ffi_api.DataflowUseInplaceCalls() + + def LambdaLift() -> tvm.ir.transform.Pass: """A pass that lifts local functions into global. diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index d503854823ff..5c49bd624098 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -17,6 +17,11 @@ * specific language governing permissions and limitations * under the License. */ +/*! + * \file src/relax/transform/dataflow_inplace.cc + * \brief Pass that converts eligible operator calls in dataflow blocks + * into in-place versions. + */ #include #include @@ -32,7 +37,11 @@ namespace tvm { namespace relax { -std::unordered_map, ObjectPtrHash, ObjectPtrEqual> analyze_liveness( +// Perform liveness analysis on a dataflow block, returning a map of vars to +// pairs of indices (the liveness interval, from the starting index to the end index). +// A starting index of -1 means the var is defined before the block starts and an end index +// of block->bindings.size() (one past the last index) means it is live after the block ends. +std::unordered_map, ObjectPtrHash, ObjectPtrEqual> AnalyzeLiveness( const DataflowBlock& block) { std::unordered_map, ObjectPtrHash, ObjectPtrEqual> ret; for (int i = block->bindings.size() - 1; i >= 0; i--) { @@ -84,8 +93,9 @@ class AliasAnalyzer { public: explicit AliasAnalyzer() : alias_map_(), tuple_map_(), mem_idx_(0) {} - // alias: map of var to memory locations (we will call these indices and use -1 as an index for - // "unknown") + // The analysis returns a map of vars to memory locations that it *could* map to + // (any unique allocation = one memory location), plus a map of memory locations + // that correspond to tuples (this maps to sets of memory locations for each tuple element) std::pair, ObjectPtrHash, ObjectPtrEqual>, std::unordered_map>>> Analyze(const DataflowBlock& block, const Array& inputs) { @@ -93,14 +103,14 @@ class AliasAnalyzer { int curr_idx = get_fresh_idx(); alias_map_[input] = {curr_idx}; if (auto* tup_info = GetStructInfoAs(input)) { - insert_fresh_tuple(curr_idx, tup_info); + InsertFreshTuple(curr_idx, tup_info); } } for (const Binding& binding : block->bindings) { Var current_var = binding->var; Expr value = GetBoundValue(binding); - alias_map_[current_var] = get_alias_set(value, current_var); + alias_map_[current_var] = GetAliasSet(value, current_var); } return {alias_map_, tuple_map_}; @@ -113,19 +123,22 @@ class AliasAnalyzer { return ret; } - void insert_fresh_tuple(int tup_idx, const TupleStructInfoNode* tup_info) { + // Fresh tuple = each element is assumed to be a unique allocation + void InsertFreshTuple(int tup_idx, const TupleStructInfoNode* tup_info) { std::vector> tuple_set; for (int i = 0; i < static_cast(tup_info->fields.size()); i++) { int curr_field = get_fresh_idx(); tuple_set.push_back({curr_field}); if (auto* nested_tup_info = tup_info->fields[i].as()) { - insert_fresh_tuple(curr_field, nested_tup_info); + InsertFreshTuple(curr_field, nested_tup_info); } } tuple_map_[tup_idx] = tuple_set; } - void update_tuple_components(int tup_idx, const std::unordered_set& insert_idxs) { + // given a tuple index, add the given memory location indices to each component's + // alias set + void UpdateTupleComponents(int tup_idx, const std::unordered_set& insert_idxs) { if (tuple_map_.count(tup_idx)) { auto tuple_comps = tuple_map_[tup_idx]; for (size_t i = 0; i < tuple_comps.size(); i++) { @@ -134,7 +147,7 @@ class AliasAnalyzer { // if a member is a tuple, update its components as well for (int member : comp_set) { if (tuple_map_.count(member)) { - update_tuple_components(member, insert_idxs); + UpdateTupleComponents(member, insert_idxs); } } @@ -146,12 +159,12 @@ class AliasAnalyzer { // capture the given index and also its tuple components (including recursively) // if they exist - void add_captured_indices(std::unordered_set* captured_set, int idx) { + void AddCapturedIndices(std::unordered_set* captured_set, int idx) { captured_set->insert(idx); if (tuple_map_.count(idx)) { for (auto comp_set : tuple_map_[idx]) { for (auto tup_comp_idx : comp_set) { - add_captured_indices(captured_set, tup_comp_idx); + AddCapturedIndices(captured_set, tup_comp_idx); } } } @@ -163,31 +176,33 @@ class AliasAnalyzer { // For tuples, assume all members are aliased. Yeah, it's bad. // (Skip first arg is for handling call_pure_packed, where the first arg is an ExternFunc that we // should ignore) - std::unordered_set handle_mystery_call(const CallNode* call_node, const Var& bound_var, - bool skip_first_arg = false) { + std::unordered_set HandleMysteryCall(const CallNode* call_node, const Var& bound_var, + bool skip_first_arg = false) { // the result may or may not be newly allocated std::unordered_set ret; int res_idx = get_fresh_idx(); // the result may be a tuple if (auto* tup_info_node = GetStructInfoAs(bound_var)) { - insert_fresh_tuple(res_idx, tup_info_node); + InsertFreshTuple(res_idx, tup_info_node); } - add_captured_indices(&ret, res_idx); + AddCapturedIndices(&ret, res_idx); for (size_t i = (skip_first_arg) ? 1 : 0; i < call_node->args.size(); i++) { auto arg = call_node->args[i]; - auto arg_alias_set = get_alias_set(arg, bound_var); + auto arg_alias_set = GetAliasSet(arg, bound_var); for (int alias_idx : arg_alias_set) { - add_captured_indices(&ret, alias_idx); + AddCapturedIndices(&ret, alias_idx); } } // if the result is a tuple, the components can also potentially be aliased to any arg // or, in fact, to each other - update_tuple_components(res_idx, ret); + UpdateTupleComponents(res_idx, ret); return ret; } - std::unordered_set get_alias_set(const Expr& value, const Var& bound_var) { + // given the expression value, return the set of memory locations corresponding to it + // (the var the expression is being bound to is needed for struct info) + std::unordered_set GetAliasSet(const Expr& value, const Var& bound_var) { std::unordered_set ret; // cases for value: @@ -217,11 +232,11 @@ class AliasAnalyzer { ret.insert(tup_idx); std::vector> new_tuple_map; for (auto field : target_tuple->fields) { - new_tuple_map.push_back(get_alias_set(field, bound_var)); + new_tuple_map.push_back(GetAliasSet(field, bound_var)); } tuple_map_[tup_idx] = new_tuple_map; } else if (auto* target_tgi = value.as()) { - std::unordered_set tuple_set = get_alias_set(target_tgi->tuple, bound_var); + std::unordered_set tuple_set = GetAliasSet(target_tgi->tuple, bound_var); // if -1 is a member of the tuple set, then we have to assume the result is -1 if (tuple_set.count(-1)) { ret.insert(-1); @@ -241,7 +256,7 @@ class AliasAnalyzer { if (auto* op_node = call_node->op.as()) { // call_pure_packed: treat as non-op call if (op_node->name == "relax.call_pure_packed") { - return handle_mystery_call(call_node, bound_var, true); + return HandleMysteryCall(call_node, bound_var, true); } // split: Returns a tuple, treat as allocation else if (op_node->name == "relax.split") { @@ -249,14 +264,14 @@ class AliasAnalyzer { int tup_idx = get_fresh_idx(); ret.insert(tup_idx); // the LHS (the bound var) will definitely have a tuple struct info - insert_fresh_tuple(tup_idx, GetStructInfoAs(bound_var)); + InsertFreshTuple(tup_idx, GetStructInfoAs(bound_var)); } // call_tir: can potentially return a tuple else if (op_node->name == "relax.call_tir") { if (auto* tuple_struct_info = call_node->sinfo_args[0].as()) { int tup_idx = get_fresh_idx(); ret.insert(tup_idx); - insert_fresh_tuple(tup_idx, tuple_struct_info); + InsertFreshTuple(tup_idx, tuple_struct_info); } else { ret.insert(get_fresh_idx()); } @@ -268,7 +283,7 @@ class AliasAnalyzer { } } else { // assume any non-op call can be extremely dangerous and do anything - return handle_mystery_call(call_node, bound_var); + return HandleMysteryCall(call_node, bound_var); } } @@ -280,7 +295,8 @@ class AliasAnalyzer { int mem_idx_; }; -int shape_size(const ShapeExpr& shape) { +// given a shape, return the allocation size corresponding to it (product of elements) +int ShapeSize(const ShapeExpr& shape) { int ret = 1; for (auto dim : shape->values) { if (auto int_dim = dim.as()) { @@ -292,7 +308,11 @@ int shape_size(const ShapeExpr& shape) { return ret; } -std::unordered_set gather_candidate_sinfo( +// Given the struct info of the result, return any struct info nested in it +// that is eleigible to be used for in-place computations (tensors are eligible +// only if all their dimensions are integer constants, tuples are eligible if +// all members are eligible though we can consider only individual members separately) +std::unordered_set GatherCandidateSinfo( const StructInfo& result_sinfo) { if (auto* tensor_info = result_sinfo.as()) { // don't consider void dtype (don't know the size at compile time) @@ -315,7 +335,7 @@ std::unordered_set gather_candidate_s // we can see if the whole tuple matches or go for any of the components std::unordered_set ret; for (auto field : tuple_info->fields) { - auto field_candidates = gather_candidate_sinfo(field); + auto field_candidates = GatherCandidateSinfo(field); ret.insert(field_candidates.begin(), field_candidates.end()); } // at least one field should be eligible to be done in-place @@ -329,7 +349,12 @@ std::unordered_set gather_candidate_s } } -std::pair size_matches(const StructInfo& target_info, const StructInfo& arg_info) { +// Given the two struct info, return a pair of bools where the first element is true if +// the two struct info are the same _size_ in memory and the second element is true +// if the shapes match _exactly_. Performs this check recursively and ensures the +// stated condition is true for all tensor members of the struct info (return false +// if a single pair of corresponding tensors does not meet the condition). +std::pair SizeMatches(const StructInfo& target_info, const StructInfo& arg_info) { if (target_info.as() && arg_info.as()) { auto target_tensor = Downcast(target_info); auto arg_tensor = Downcast(arg_info); @@ -340,8 +365,8 @@ std::pair size_matches(const StructInfo& target_info, const StructIn } auto target_shape = Downcast(target_tensor->shape); auto arg_shape = Downcast(arg_tensor->shape); - int target_size = shape_size(target_shape); - int arg_size = shape_size(arg_shape); + int target_size = ShapeSize(target_shape); + int arg_size = ShapeSize(arg_shape); if (target_size == -1 || arg_size == -1 || target_size < arg_size) { return {false, false}; } @@ -370,7 +395,7 @@ std::pair size_matches(const StructInfo& target_info, const StructIn } bool all_exact = true; for (size_t i = 0; i < target_tup->fields.size(); i++) { - auto element_match = size_matches(target_tup->fields[i], arg_tup->fields[i]); + auto element_match = SizeMatches(target_tup->fields[i], arg_tup->fields[i]); if (!element_match.first) { return {false, false}; } @@ -390,7 +415,7 @@ std::pair size_matches(const StructInfo& target_info, const StructIn // members if so (apply recursively if any of those members are tuples). // Return false if the alias set contains -1, meaning a reference to an unknown or // possibly dangerous value (no checking we can do for that). -bool gather_sets_to_check_for_liveness( +bool GatherSetsToCheckForLiveness( const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& alias_sets, const std::unordered_map>>& tuple_map, @@ -406,7 +431,7 @@ bool gather_sets_to_check_for_liveness( // if a member can be a tuple, check it recursively for (int member : member_set) { if (tuple_map.count(member)) { - if (!gather_sets_to_check_for_liveness(alias_sets, tuple_map, sets_to_check, member)) { + if (!GatherSetsToCheckForLiveness(alias_sets, tuple_map, sets_to_check, member)) { return false; } } @@ -416,9 +441,9 @@ bool gather_sets_to_check_for_liveness( return true; } -// check that the target is not live past the index and that no alias of it is live past the +// Check that the target is not live past the index and that no alias of it is live past the // binding index (if the target is a tuple, check the conditions recursively for the members) -bool df_inplace_conditions_met( +bool InplaceConditionsMet( const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& live_ranges, const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& alias_sets, @@ -449,7 +474,7 @@ bool df_inplace_conditions_met( // If a possible alias is a tuple, we will also check for aliases of the members // (possibly recursively) for (int alias_idx : alias_set) { - if (!gather_sets_to_check_for_liveness(alias_sets, tuple_map, &sets_to_check, alias_idx)) { + if (!GatherSetsToCheckForLiveness(alias_sets, tuple_map, &sets_to_check, alias_idx)) { return false; } } @@ -474,8 +499,7 @@ bool df_inplace_conditions_met( return true; } else if (auto* tup_node = target.as()) { for (auto field : tup_node->fields) { - if (!df_inplace_conditions_met(live_ranges, alias_sets, tuple_map, currently_live, field, - idx)) { + if (!InplaceConditionsMet(live_ranges, alias_sets, tuple_map, currently_live, field, idx)) { return false; } } @@ -489,16 +513,24 @@ bool df_inplace_conditions_met( static std::unordered_set SUPPORTED_OPS = {"relax.add", "relax.subtract", "relax.multiply", "relax.divide", "relax.nn.silu", "relax.nn.relu"}; -bool op_supports_inplace(const Op& op) { return SUPPORTED_OPS.count(op->name); } +bool OpSupportsInplace(const Op& op) { return SUPPORTED_OPS.count(op->name); } -// check for in-place eligibility: +// Check for in-place eligibility: // 1. see if there's an arg big enough to hold the result // 2. see if the arg is live past the call // 3. see if the arg has an alias that's live past the call -// if conditions are met, we're good to go -std::pair>, std::vector>> find_inplace_opportunities( +// If the conditions are met, record the index of that binding. +// Returns two lists of lists: +// 1. A list of bindings where at least one argument meets the in-place conditions and the *size* +// matches the size of the result. +// 2. A list of bindings where at least one argument meets the in-place conditions +// and *exactly* matches the shape of the result. +// For both lists, each element is a list of ints of the following format: +// The first element is the index of the *binding* in the block. +// All remaining elements are the indices of *eligible arguments* in that call. +std::pair>, std::vector>> FindInplaceOpportunities( const DataflowBlock& block, const Array& inputs) { - auto live_ranges = analyze_liveness(block); + auto live_ranges = AnalyzeLiveness(block); AliasAnalyzer analyzer; auto alias_info = analyzer.Analyze(block, inputs); auto alias_sets = alias_info.first; @@ -539,14 +571,14 @@ std::pair>, std::vector>> find_inp if (auto* call_node = value.as()) { if (auto* op_node = call_node->op.as()) { - if (!op_supports_inplace(GetRef(op_node))) { + if (!OpSupportsInplace(GetRef(op_node))) { continue; } std::unordered_set candidates; std::unordered_set exact_match_candidates; - auto target_sinfo = gather_candidate_sinfo(GetStructInfo(defined_var)); + auto target_sinfo = GatherCandidateSinfo(GetStructInfo(defined_var)); // can't be done in-place, ignore if (target_sinfo.empty()) { continue; @@ -556,7 +588,7 @@ std::pair>, std::vector>> find_inp for (size_t j = 0; j < call_node->args.size(); j++) { auto arg = call_node->args[j]; for (auto target : target_sinfo) { - std::pair match = size_matches(target, GetStructInfo(arg)); + std::pair match = SizeMatches(target, GetStructInfo(arg)); if (match.first) { candidates.insert(static_cast(j)); if (match.second) { @@ -573,8 +605,8 @@ std::pair>, std::vector>> find_inp // live past this point std::unordered_set remove_candidates; for (auto candidate : candidates) { - if (!df_inplace_conditions_met(live_ranges, alias_sets, tuple_map, currently_live, - call_node->args[candidate], i)) { + if (!InplaceConditionsMet(live_ranges, alias_sets, tuple_map, currently_live, + call_node->args[candidate], i)) { remove_candidates.insert(candidate); } } @@ -623,7 +655,8 @@ std::pair>, std::vector>> find_inp return {size_match_list, exact_match_list}; } -tir::Stmt remap_buffers(const tir::Stmt& stmt, const Map& buffer_map) { +// Replace buffers in a PrimFunc according to the mapping. +tir::Stmt RemapBuffers(const tir::Stmt& stmt, const Map& buffer_map) { class BufferMapper : public tir::StmtExprMutator { public: explicit BufferMapper(const Map& buffer_map) @@ -751,7 +784,7 @@ class ModuleInplaceTransformer : public ExprMutator { Array free_var_list(free_var_set.begin(), free_var_set.end()); // for now, only handle exact match cases - auto matches_found = find_inplace_opportunities(block, free_var_list); + auto matches_found = FindInplaceOpportunities(block, free_var_list); auto exact_matches = matches_found.second; Array new_bindings; @@ -785,7 +818,9 @@ class ModuleInplaceTransformer : public ExprMutator { return DataflowBlock(new_bindings, block->span); } - // exposed for testing + // Given the call and indices of arguments that could be done in-place, + // replace the call with a call to an in-place PrimFunc. + // (Made public for testing.) Call CreateInplaceCall(const Call& call, const Array& inplace_indices) { static const auto& legalize_map = Op::GetAttrMap("FLegalize"); static const auto& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); @@ -829,7 +864,7 @@ class ModuleInplaceTransformer : public ExprMutator { } // apply substitutions - legal_primfunc_cow->body = remap_buffers(legal_primfunc->body, buffer_subst_map); + legal_primfunc_cow->body = RemapBuffers(legal_primfunc->body, buffer_subst_map); legal_primfunc_cow->body = tir::Substitute( legal_primfunc->body, [&var_subst_map](const tir::Var& v) -> Optional { if (var_subst_map.count(v)) { @@ -868,7 +903,7 @@ class ModuleInplaceTransformer : public ExprMutator { return legalized_call; } - // exposed for testing + // Made public for testing. IRModule CurrentMod() { return builder_->GetContextIRModule(); } private: @@ -879,11 +914,10 @@ class ModuleInplaceTransformer : public ExprMutator { // (we are assuming good behavior on the user's part). }; -// export for testing namespace transform { Map> DataflowLivenessAnalysis(const DataflowBlock& block) { - auto liveness_ranges = analyze_liveness(block); + auto liveness_ranges = AnalyzeLiveness(block); Map> ret; for (auto kv : liveness_ranges) { ret.Set(kv.first, {kv.second.first, kv.second.second}); @@ -919,9 +953,20 @@ Array DataflowAliasAnalysis(const DataflowBlock& block, Array in return {new_alias_sets, new_tuple_map}; } +// this would be preferable to do as a dataflow block pass, +// but the transformation adds new PrimFuncs, so it affects the module +tvm::transform::Pass DataflowUseInplaceCalls() { + return tvm::transform::CreateModulePass( + [](const IRModule& mod, const PassContext& ctx) -> IRModule { + ModuleInplaceTransformer transformer(mod); + return transformer.Transform(); + }, + 0, "DataflowInsertInPlaceCalls", {}, false); +} + Array>> DataflowInplaceAnalysis(const DataflowBlock& block, const Array& inputs) { - auto index_lists = relax::find_inplace_opportunities(block, inputs); + auto index_lists = relax::FindInplaceOpportunities(block, inputs); Array> size_match_array; for (auto indices : index_lists.first) { Array index_array; @@ -941,14 +986,14 @@ Array>> DataflowInplaceAnalysis(const DataflowBlock& block, return {size_match_array, exact_match_array}; } -TVM_REGISTER_GLOBAL("relax.analysis.DataflowLivenessAnalysis") +// these are exposed only for testing +TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowLivenessAnalysis") .set_body_typed(DataflowLivenessAnalysis); -TVM_REGISTER_GLOBAL("relax.analysis.DataflowAliasAnalysis").set_body_typed(DataflowAliasAnalysis); -TVM_REGISTER_GLOBAL("relax.analysis.DataflowInplaceAnalysis") +TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowAliasAnalysis") + .set_body_typed(DataflowAliasAnalysis); +TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowInplaceAnalysis") .set_body_typed(DataflowInplaceAnalysis); - -// really only for testing (not actually an analysis, will move) -TVM_REGISTER_GLOBAL("relax.analysis.SingleInplaceCall") +TVM_REGISTER_GLOBAL("relax.testing.transform.SingleInplaceCall") .set_body_typed([](const IRModule& mod, const Call& call, const Array& inplace_indices) -> Array { ModuleInplaceTransformer transformer(mod); @@ -956,15 +1001,9 @@ TVM_REGISTER_GLOBAL("relax.analysis.SingleInplaceCall") return Array{ret_call, transformer.CurrentMod()}; }); -// not actually an analysis, will rename -TVM_REGISTER_GLOBAL("relax.analysis.DataflowInsertInplaceCalls").set_body_typed([]() -> Pass { - return tvm::transform::CreateModulePass( - [](const IRModule& mod, const PassContext& ctx) -> IRModule { - ModuleInplaceTransformer transformer(mod); - return transformer.Transform(); - }, - 0, "DataflowInsertInPlaceCalls", {}, false); -}); +// actually exposed +TVM_REGISTER_GLOBAL("relax.transform.DataflowUseInplaceCalls") + .set_body_typed(DataflowUseInplaceCalls); } // namespace transform } // namespace relax diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index 23de5675b667..435164f3ebea 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -18,12 +18,12 @@ from typing import List, Set, Tuple import tvm from tvm import relax, testing -from tvm.relax.analysis import ( +from tvm.relax.transform import DataflowUseInplaceCalls +from tvm.relax.testing.transform import ( dataflow_liveness_analysis, dataflow_alias_analysis, dataflow_inplace_analysis, dataflow_single_inplace_call, - dataflow_insert_inplace_calls, ) from tvm.script.parser import ir as I, relax as R, tir as T @@ -432,7 +432,7 @@ def main( R.output(m) return m - transform_pass = dataflow_insert_inplace_calls() + transform_pass = DataflowUseInplaceCalls() new_mod = transform_pass(EndToEndTest) # check that all operations are done in-place From bbd1d53fbf6bb89495472a71a8460baa70d26dee Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 14 Nov 2023 19:14:22 -0500 Subject: [PATCH 26/55] Include proper bounds check --- src/relax/transform/dataflow_inplace.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 5c49bd624098..9eb69431f213 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -790,11 +790,12 @@ class ModuleInplaceTransformer : public ExprMutator { Array new_bindings; int current_match_index = 0; for (size_t i = 0; i < block->bindings.size(); i++) { - int candidate_binding_idx = exact_matches[current_match_index][0]; - if (candidate_binding_idx != static_cast(i)) { + if (current_match_index >= static_cast(exact_matches.size()) || + exact_matches[current_match_index][0] != static_cast(i)) { new_bindings.push_back(block->bindings[i]); continue; } + auto target_binding = block->bindings[i]; auto target_call = Downcast(GetBoundValue(target_binding)); // can just pick the first index arbitrarily (only using one output for now too) From 47639e2a4746e977672b615c1f0837cecf7b5149 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 14 Nov 2023 21:23:19 -0500 Subject: [PATCH 27/55] Trailing whitespace --- include/tvm/relax/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 42d1d18d8a32..bc6186a9146b 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -576,7 +576,7 @@ TVM_DLL Pass DeadCodeElimination(Array entry_functions); * \brief Pass that changes calls to supported operators in dataflow blocks into in-place * implementations. Supported operators will be replaced by calls to `call_tir_inplace` that invoke * in-place PrimFunc implementations of those operators (which are based on the legalizations of - * those operators). + * those operators). * \return The pass. */ TVM_DLL Pass DataflowUseInplaceCalls(); From ba7fb3d0b22ebea478a851c2967cfb14fcb527d2 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 14 Nov 2023 21:52:08 -0500 Subject: [PATCH 28/55] Need a trailing newline --- src/relax/transform/dataflow_inplace.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 9eb69431f213..d44715c0bb69 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -1008,4 +1008,4 @@ TVM_REGISTER_GLOBAL("relax.transform.DataflowUseInplaceCalls") } // namespace transform } // namespace relax -} // namespace tvm \ No newline at end of file +} // namespace tvm From ed7595a3f5c604e3549c5ad97ae16e29b9076da7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 15 Nov 2023 14:58:07 -0500 Subject: [PATCH 29/55] Remove unused imports --- python/tvm/relax/analysis/analysis.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 8e68b12d45ae..38f5ea2fea0e 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -21,13 +21,12 @@ configuring the passes and scripting them in Python. """ -from typing import Dict, List, Optional, Set, Tuple, Union, Callable +from typing import Dict, List, Optional, Union, Callable from enum import IntEnum import tvm from tvm import tir from tvm import IRModule -from tvm.relax import BlockBuilder from tvm.relax.ty import Type from tvm.relax.struct_info import StructInfo, FuncStructInfo from tvm.relax.expr import DataflowBlock, Var, GlobalVar, Expr, Function, Call, Binding From 77eab796d71617d9fd919262d07912aab448efd1 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 15 Nov 2023 14:58:23 -0500 Subject: [PATCH 30/55] Add docstrings for exposed inner functions --- python/tvm/relax/testing/transform.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index aeb9362729d1..9ba16d1f90cf 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -131,7 +131,7 @@ def ApplyEmptyCppMutator() -> tvm.ir.transform.Pass: return packed_func() -# inner functions for the dataflow inplace transformation exposed for testing +# inner function for the dataflow inplace transformation exposed for testing def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int]]: live_ranges = tvm.get_global_func("relax.testing.transform.DataflowLivenessAnalysis")( block @@ -142,6 +142,7 @@ def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int return ret # type: ignore +# inner function for the dataflow inplace transformation exposed for testing def dataflow_alias_analysis( block: DataflowBlock, inputs: List[Var] ) -> Tuple[Dict[Var, Set[int]], Dict[int, List[Set[int]]]]: @@ -158,6 +159,7 @@ def dataflow_alias_analysis( return res_alias_sets, res_tuple_map # type: ignore +# inner function for the dataflow inplace transformation exposed for testing def dataflow_inplace_analysis( block: DataflowBlock, inputs: List[Var] ) -> Tuple[List[int], List[int]]: @@ -168,6 +170,7 @@ def dataflow_inplace_analysis( return tuple(map(list, index_lists)) # type: ignore +# inner function for the dataflow inplace transformation exposed for testing def dataflow_single_inplace_call( mod: IRModule, call: Call, inplace_indices: List[int] ) -> Tuple[Call, IRModule]: From 71e4d713d5955ad5174c3b8e457789b6a9b48bb5 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 16 Nov 2023 15:23:29 -0500 Subject: [PATCH 31/55] Reformat docstrings to appease the linter --- python/tvm/relax/testing/transform.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 9ba16d1f90cf..4521dc46fa93 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -131,8 +131,10 @@ def ApplyEmptyCppMutator() -> tvm.ir.transform.Pass: return packed_func() -# inner function for the dataflow inplace transformation exposed for testing def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int]]: + """ + Inner function for the dataflow inplace transformation exposed for testing. + """ live_ranges = tvm.get_global_func("relax.testing.transform.DataflowLivenessAnalysis")( block ) # type: ignore @@ -142,10 +144,12 @@ def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int return ret # type: ignore -# inner function for the dataflow inplace transformation exposed for testing def dataflow_alias_analysis( block: DataflowBlock, inputs: List[Var] ) -> Tuple[Dict[Var, Set[int]], Dict[int, List[Set[int]]]]: + """ + Inner function for the dataflow inplace transformation exposed for testing. + """ alias_sets, tuple_map = tvm.get_global_func("relax.testing.transform.DataflowAliasAnalysis")( block, inputs, @@ -159,10 +163,12 @@ def dataflow_alias_analysis( return res_alias_sets, res_tuple_map # type: ignore -# inner function for the dataflow inplace transformation exposed for testing def dataflow_inplace_analysis( block: DataflowBlock, inputs: List[Var] ) -> Tuple[List[int], List[int]]: + """ + Inner function for the dataflow inplace transformation exposed for testing. + """ index_lists = tvm.get_global_func("relax.testing.transform.DataflowInplaceAnalysis")( block, inputs, @@ -170,10 +176,12 @@ def dataflow_inplace_analysis( return tuple(map(list, index_lists)) # type: ignore -# inner function for the dataflow inplace transformation exposed for testing def dataflow_single_inplace_call( mod: IRModule, call: Call, inplace_indices: List[int] ) -> Tuple[Call, IRModule]: + """ + Inner function for the dataflow inplace transformation exposed for testing. + """ ret = tvm.get_global_func("relax.testing.transform.SingleInplaceCall")( mod, call, From 17724beec4a5725175049d7780c6757a66901e54 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 16 Nov 2023 17:15:56 -0500 Subject: [PATCH 32/55] C++ stylistic changes --- src/relax/transform/dataflow_inplace.cc | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index d44715c0bb69..7399e258d861 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -59,7 +59,6 @@ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> Anal // create tuples to be done in-place (otherwise, any index of the tuple // would be considered a use and so the tuple would be live later). // Hence we keep the array empty. - ; } else { used_vars = AllVars(value); } @@ -91,7 +90,7 @@ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> Anal class AliasAnalyzer { public: - explicit AliasAnalyzer() : alias_map_(), tuple_map_(), mem_idx_(0) {} + AliasAnalyzer() : alias_map_(), tuple_map_(), mem_idx_(0) {} // The analysis returns a map of vars to memory locations that it *could* map to // (any unique allocation = one memory location), plus a map of memory locations @@ -257,17 +256,16 @@ class AliasAnalyzer { // call_pure_packed: treat as non-op call if (op_node->name == "relax.call_pure_packed") { return HandleMysteryCall(call_node, bound_var, true); - } - // split: Returns a tuple, treat as allocation - else if (op_node->name == "relax.split") { + } else if (op_node->name == "relax.split") { + // split: Returns a tuple, treat as allocation + // tuple is freshly allocated, but also add components to the tuple map int tup_idx = get_fresh_idx(); ret.insert(tup_idx); // the LHS (the bound var) will definitely have a tuple struct info InsertFreshTuple(tup_idx, GetStructInfoAs(bound_var)); - } - // call_tir: can potentially return a tuple - else if (op_node->name == "relax.call_tir") { + } else if (op_node->name == "relax.call_tir") { + // call_tir: can potentially return a tuple if (auto* tuple_struct_info = call_node->sinfo_args[0].as()) { int tup_idx = get_fresh_idx(); ret.insert(tup_idx); @@ -275,10 +273,9 @@ class AliasAnalyzer { } else { ret.insert(get_fresh_idx()); } - } - // We are assuming most op calls return a single fresh allocation. - // We may have to track more exceptions - else { + } else { + // We are assuming most op calls return a single fresh allocation. + // We may have to track more exceptions ret.insert(get_fresh_idx()); } } else { From a3299150a0286097f37fae0301af6af59d744432 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 21 Nov 2023 18:27:33 -0500 Subject: [PATCH 33/55] Treat args as mystery values by default, do not allow overwriting --- src/relax/transform/dataflow_inplace.cc | 41 +++++++++++++-------- tests/python/relax/test_dataflow_inplace.py | 16 ++++---- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 7399e258d861..4215a01fba79 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -94,7 +94,9 @@ class AliasAnalyzer { // The analysis returns a map of vars to memory locations that it *could* map to // (any unique allocation = one memory location), plus a map of memory locations - // that correspond to tuples (this maps to sets of memory locations for each tuple element) + // that correspond to tuples (this maps to sets of memory locations for each tuple element). + // Note: inputs are values that should be assumed not to be aliased and are therefore + // (in the case of in-place ops) safe to overwrite. This may not be true of function args. std::pair, ObjectPtrHash, ObjectPtrEqual>, std::unordered_map>>> Analyze(const DataflowBlock& block, const Array& inputs) { @@ -209,7 +211,7 @@ class AliasAnalyzer { // var: look up in alias map (-1 if not present) // op call: assume it's fresh (may need to make list of exceptions) // tuple: fresh entry in tuple index, recurse to determine indices for values - // function/packed call: chaos reigns, alias with everything ever passed or returned from func + // function/packed call: chaos reigns, alias with any other argument // (if tuple is passed, assume also aliased with all members of the tuple) // tuple index: -1 if tuple is not in tuple map, otherwise look up corresponding entry // function constant: give them a fresh index (TODO: we can handle in more detail if this is a @@ -446,13 +448,13 @@ bool InplaceConditionsMet( alias_sets, const std::unordered_map>>& tuple_map, const std::unordered_set& currently_live, - const Expr& target, int idx) { + const Expr& target, int binding_idx) { if (auto* var_node = target.as()) { auto current_var = GetRef(var_node); // if the var is live past this point, we can't use it for in-place computations anyway if (live_ranges.count(current_var)) { auto live_range = live_ranges.at(current_var); - if (live_range.second > idx) { + if (live_range.second > binding_idx) { return false; } } @@ -477,11 +479,15 @@ bool InplaceConditionsMet( } for (Var other_var : currently_live) { + if (other_var.same_as(target)) { + continue; + } + // not represented = spooky unknown value that should be modeled by -1 if (!alias_sets.count(other_var) || !live_ranges.count(other_var)) { - return false; + continue; } // var is not live past this point => don't need to worry - if (live_ranges.at(other_var).second <= idx) { + if (live_ranges.at(other_var).second <= binding_idx) { continue; } auto other_alias_set = alias_sets.at(other_var); @@ -496,7 +502,8 @@ bool InplaceConditionsMet( return true; } else if (auto* tup_node = target.as()) { for (auto field : tup_node->fields) { - if (!InplaceConditionsMet(live_ranges, alias_sets, tuple_map, currently_live, field, idx)) { + if (!InplaceConditionsMet(live_ranges, alias_sets, tuple_map, currently_live, field, + binding_idx)) { return false; } } @@ -773,15 +780,17 @@ class ModuleInplaceTransformer : public ExprMutator { // and replace any valid calls in them BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { auto block = GetRef(op); - std::unordered_set free_var_set; - for (auto binding : block->bindings) { - auto binding_free_vars = FreeVars(GetBoundValue(binding)); - free_var_set.insert(binding_free_vars.begin(), binding_free_vars.end()); - } - Array free_var_list(free_var_set.begin(), free_var_set.end()); - - // for now, only handle exact match cases - auto matches_found = FindInplaceOpportunities(block, free_var_list); + // std::unordered_set free_var_set; + // for (auto binding : block->bindings) { + // auto binding_free_vars = FreeVars(GetBoundValue(binding)); + // free_var_set.insert(binding_free_vars.begin(), binding_free_vars.end()); + // } + // Array free_var_list(free_var_set.begin(), free_var_set.end()); + + // For now, only handle exact match cases. + // Note: Not passing any input values for now, as we can't make any assumptions + // about them. + auto matches_found = FindInplaceOpportunities(block, {}); auto exact_matches = matches_found.second; Array new_bindings; diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index 435164f3ebea..1213776d2fa3 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -426,9 +426,12 @@ def main( ) -> R.Tensor((2, 3), dtype="float32"): with R.dataflow(): z = R.add(x, y) # broadcast happens here - q = R.multiply(z, y) # broadcast again - r = R.subtract(y, y) # now can be done inplace - m = R.multiply(q, r) # should give us all zeros + # Cannot be done in-place because x is an argument. + a = R.add(z, y) # this one can be done in-place + q = R.multiply(a, y) # broadcast again, a is eligible + r = R.subtract(y, y) # cannot be done in-place because y is an argument + s = R.subtract(r, r) # No broadcast. Can be done in-place + m = R.multiply(q, s) # should give us all zeros R.output(m) return m @@ -440,7 +443,9 @@ def main( assert new_mod["subtract_inplace"] assert new_mod["multiply_inplace"] expected_ops = ["add_inplace", "multiply_inplace", "subtract_inplace", "multiply_inplace"] - for i, binding in enumerate(new_mod["main"].body.blocks[0].bindings): + inplace_binding_idxs = [1, 2, 4, 5] + for i, idx in enumerate(inplace_binding_idxs): + binding = new_mod["main"].body.blocks[0].bindings[idx] assert binding.value.op.name == "relax.call_tir_inplace" assert binding.value.args[0].name_hint == expected_ops[i] @@ -452,9 +457,6 @@ def main( ex = relax.build(new_mod, target) vm = relax.VirtualMachine(ex, tvm.cpu()) res = vm["main"](x, y) - # due to reuse of buffers, the result is actually reference equal to argument x - # (we can disable this by setting the arguments to "unknown value" in the alias analysis) - assert res == x assert (expected == res.numpy()).all() From 29bcdfb9039a129289faa6a2272982ecf5a33eee Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 22 Nov 2023 15:16:58 -0500 Subject: [PATCH 34/55] Formatting --- tests/python/relax/test_dataflow_inplace.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index 1213776d2fa3..e9e9f2da1de8 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -426,11 +426,11 @@ def main( ) -> R.Tensor((2, 3), dtype="float32"): with R.dataflow(): z = R.add(x, y) # broadcast happens here - # Cannot be done in-place because x is an argument. - a = R.add(z, y) # this one can be done in-place + # Cannot be done in-place because x is an argument. + a = R.add(z, y) # this one can be done in-place q = R.multiply(a, y) # broadcast again, a is eligible r = R.subtract(y, y) # cannot be done in-place because y is an argument - s = R.subtract(r, r) # No broadcast. Can be done in-place + s = R.subtract(r, r) # No broadcast. Can be done in-place m = R.multiply(q, s) # should give us all zeros R.output(m) return m From cf396a1b443f0c5a0e6a556dfcf293619b5207b7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 13 Jan 2024 19:22:06 -0500 Subject: [PATCH 35/55] Clarify pass description --- include/tvm/relax/transform.h | 9 +++++---- python/tvm/relax/transform/transform.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index bc6186a9146b..efe30e5cbb50 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -573,10 +573,11 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2); TVM_DLL Pass DeadCodeElimination(Array entry_functions); /*! - * \brief Pass that changes calls to supported operators in dataflow blocks into in-place - * implementations. Supported operators will be replaced by calls to `call_tir_inplace` that invoke - * in-place PrimFunc implementations of those operators (which are based on the legalizations of - * those operators). + * \brief Pass that changes calls to operators that can be done in-place + * (generally, these are elementwise operations) in dataflow blocks into in-place implementations. + * Supported operators will be replaced by calls to `call_tir_inplace` that invoke in-place + * PrimFunc implementations of those operators (which are based on the legalizations of those + * operators). * \return The pass. */ TVM_DLL Pass DataflowUseInplaceCalls(); diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 379a22b487f0..99fdc67c29ce 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -254,8 +254,9 @@ def RemovePurityChecking() -> tvm.ir.transform.Pass: def DataflowUseInplaceCalls() -> tvm.ir.transform.Pass: """ - Pass that changes calls to supported operators in dataflow blocks into in-place - implementations. Supported operators will be replaced by calls to `call_tir_inplace` that invoke + Pass that changes calls to operators that can be done in-place + (generally, these are elementwise operations) into in-place implementations. + Supported operators will be replaced by calls to `call_tir_inplace` that invoke in-place PrimFunc implementations of those operators (which are based on the legalizations of those operators). From b6c7c368b5caffbd0324c8a0d13dbcea34364c88 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 13 Jan 2024 20:31:10 -0500 Subject: [PATCH 36/55] Add check to ensure that testing functions are used only in a testing environment --- python/tvm/relax/testing/transform.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 4521dc46fa93..7e3e124364f7 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -17,6 +17,8 @@ # pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ """Relax transformation passes for testing""" +import logging +import os from typing import Dict, List, Set, Tuple import tvm from tvm import ir, relax @@ -135,6 +137,11 @@ def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int """ Inner function for the dataflow inplace transformation exposed for testing. """ + if "PYTEST_CURRENT_TEST" not in os.environ: + logging.warning( + "The function dataflow_liveness_analysis is exposed for testing only." + ) + live_ranges = tvm.get_global_func("relax.testing.transform.DataflowLivenessAnalysis")( block ) # type: ignore @@ -150,6 +157,11 @@ def dataflow_alias_analysis( """ Inner function for the dataflow inplace transformation exposed for testing. """ + if "PYTEST_CURRENT_TEST" not in os.environ: + logging.warning( + "The function dataflow_alias_analysis is exposed for testing only." + ) + alias_sets, tuple_map = tvm.get_global_func("relax.testing.transform.DataflowAliasAnalysis")( block, inputs, @@ -169,6 +181,10 @@ def dataflow_inplace_analysis( """ Inner function for the dataflow inplace transformation exposed for testing. """ + if "PYTEST_CURRENT_TEST" not in os.environ: + logging.warning( + "The function dataflow_inplace_analysis is exposed for testing only." + ) index_lists = tvm.get_global_func("relax.testing.transform.DataflowInplaceAnalysis")( block, inputs, @@ -182,6 +198,11 @@ def dataflow_single_inplace_call( """ Inner function for the dataflow inplace transformation exposed for testing. """ + if "PYTEST_CURRENT_TEST" not in os.environ: + logging.warning( + "The function dataflow_single_inplace_call is exposed for testing only." + ) + ret = tvm.get_global_func("relax.testing.transform.SingleInplaceCall")( mod, call, From aad9aabe2002515c6b9e9f4e414ba32adba3b495 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 13 Jan 2024 20:32:07 -0500 Subject: [PATCH 37/55] Improve size match check readability per review suggestions --- src/relax/transform/dataflow_inplace.cc | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 4215a01fba79..245c7528eae3 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -394,17 +394,22 @@ std::pair SizeMatches(const StructInfo& target_info, const StructInf } bool all_exact = true; for (size_t i = 0; i < target_tup->fields.size(); i++) { - auto element_match = SizeMatches(target_tup->fields[i], arg_tup->fields[i]); - if (!element_match.first) { - return {false, false}; + // if members aren't either tuples or tensors, simply skip them, + // since they don't matter for in-place computations + if (!(target_tup->fields[i].as() || + target_tup->fields[i].as()) && + !(arg_tup->fields[i].as() || + arg_tup->fields[i].as())) { + continue; } - if (!element_match.second) { - all_exact = false; + auto [field_size_match, field_exact_match] = + SizeMatches(target_tup->fields[i], arg_tup->fields[i]); + if (!field_size_match) { + return {false, false}; } + all_exact = all_exact && field_exact_match; } return {true, all_exact}; - } else if (target_info.as() && arg_info.as()) { - return {true, true}; } else { return {false, false}; } From 1470d24a9c56ebab00e3c6ede8dd47b08feea987 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 13 Jan 2024 20:43:16 -0500 Subject: [PATCH 38/55] Improve the size match check per review suggestions (use PrimExprs) --- src/relax/transform/dataflow_inplace.cc | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 245c7528eae3..02daf23bd3f4 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -294,15 +294,11 @@ class AliasAnalyzer { int mem_idx_; }; -// given a shape, return the allocation size corresponding to it (product of elements) -int ShapeSize(const ShapeExpr& shape) { - int ret = 1; +// given a shape, return the number of elements corresponding to it (product of elements) +PrimExpr NumElements(const ShapeExpr& shape) { + PrimExpr ret = IntImm(DataType::Int(64), 1); for (auto dim : shape->values) { - if (auto int_dim = dim.as()) { - ret *= static_cast(int_dim->value); - } else { - return -1; - } + ret *= dim; } return ret; } @@ -364,9 +360,10 @@ std::pair SizeMatches(const StructInfo& target_info, const StructInf } auto target_shape = Downcast(target_tensor->shape); auto arg_shape = Downcast(arg_tensor->shape); - int target_size = ShapeSize(target_shape); - int arg_size = ShapeSize(arg_shape); - if (target_size == -1 || arg_size == -1 || target_size < arg_size) { + PrimExpr target_size = NumElements(target_shape); + PrimExpr arg_size = NumElements(arg_shape); + if (!target_size.as() || !arg_size.as() || + Downcast(target_size)->value < Downcast(arg_size)->value) { return {false, false}; } // exact match: number of dims and each dim matches From af5fd05f2ebe5cca3acdc4237a0de942faffdae4 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 13 Jan 2024 21:32:45 -0500 Subject: [PATCH 39/55] Treat non-dataflow vars as living past the end of the block in all cases --- src/relax/transform/dataflow_inplace.cc | 8 +++++++- tests/python/relax/test_dataflow_inplace.py | 16 +++++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 02daf23bd3f4..c2e58dfd785a 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -64,8 +64,14 @@ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> Anal } for (auto var : used_vars) { + int range_end = i; + // if the var is not a dataflow var, then it is live + // after the block (we are not checking later blocks) + if (!var.as()) { + range_end = block->bindings.size(); + } if (!ret.count(var)) { - ret[var] = {-1, i}; + ret[var] = {-1, range_end}; } } diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index e9e9f2da1de8..c0459e06ac53 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -41,17 +41,20 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): q = R.multiply(z, y) p = R.add(z, q) n = R.multiply(p, p) - R.output(n) + R.output(n, p) return n block = BasicLiveness["main"].body.blocks[0] + print(block.bindings[-2].var, type(block.bindings[-2].var)) live_ranges = dataflow_liveness_analysis(block) expected_ranges = { - "x": (-1, 1), + # x is live past the binding block + "x": (-1, 5), "y": (0, 2), "z": (1, 3), "q": (2, 3), - "p": (3, 4), + # exposed though ultimately not used + "p": (3, 5), "n": (4, 5), } for var, live_range in live_ranges.items(): @@ -326,8 +329,7 @@ def main( ) -> R.Tensor((2, 3), "int32"): with R.dataflow(): z = R.add(x, y) # cannot be done inplace: x and y are live later - q = R.multiply(x, y) # can be done inplace: neither x nor y is used later - p = R.add(z, q) # can be done inplace: neither z nor q is used later + p = R.add(z, z) # can be done inplace: z is not used later r = p # alias of p m = R.multiply(p, p) # p is not used later but r is, so can't do inplace n = R.add(m, r) # can be done inplace: r is not used again @@ -349,10 +351,10 @@ def assert_candidate_list( for j in range(len(expected[i][1])): assert actual[i][j + 1] in expected[i][1] - assert_candidate_list(size_match, [(1, {0, 1}), (2, {0, 1}), (5, {1}), (6, {0, 1})]) + assert_candidate_list(size_match, [(1, {0, 1}), (4, {1}), (5, {0, 1})]) # TODO(@slyubomirsky): I couldn't think of an easy example where sizes don't match, # but broadcasting might cause it to happen - assert_candidate_list(exact_match, [(1, {0, 1}), (2, {0, 1}), (5, {1}), (6, {0, 1})]) + assert_candidate_list(exact_match, [(1, {0, 1}), (4, {1}), (5, {0, 1})]) def test_inplace_single_call(): From 23a4d0dbad96ebee9452e624da3efb08912396d7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 13 Jan 2024 21:34:30 -0500 Subject: [PATCH 40/55] Clarify notion of size in comment --- src/relax/transform/dataflow_inplace.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index c2e58dfd785a..9c31b39d35c5 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -351,7 +351,7 @@ std::unordered_set GatherCandidateSin } // Given the two struct info, return a pair of bools where the first element is true if -// the two struct info are the same _size_ in memory and the second element is true +// the two struct info have the same number of elements and dtype and the second element is true // if the shapes match _exactly_. Performs this check recursively and ensures the // stated condition is true for all tensor members of the struct info (return false // if a single pair of corresponding tensors does not meet the condition). From c604d3b11944b2a100a17b03174b8857b8b58219 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 13 Jan 2024 21:35:06 -0500 Subject: [PATCH 41/55] Remove commented-out code --- src/relax/transform/dataflow_inplace.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 9c31b39d35c5..ecdacfae4a47 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -788,12 +788,6 @@ class ModuleInplaceTransformer : public ExprMutator { // and replace any valid calls in them BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { auto block = GetRef(op); - // std::unordered_set free_var_set; - // for (auto binding : block->bindings) { - // auto binding_free_vars = FreeVars(GetBoundValue(binding)); - // free_var_set.insert(binding_free_vars.begin(), binding_free_vars.end()); - // } - // Array free_var_list(free_var_set.begin(), free_var_set.end()); // For now, only handle exact match cases. // Note: Not passing any input values for now, as we can't make any assumptions From 3060ea11bb9d6560364ab90222e03d0d875124df Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 13 Jan 2024 21:45:02 -0500 Subject: [PATCH 42/55] Assume any op that returns a tuple is returning a fresh one (exceptions can be noted later) --- src/relax/transform/dataflow_inplace.cc | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index ecdacfae4a47..74c9cc4c5b8a 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -264,14 +264,6 @@ class AliasAnalyzer { // call_pure_packed: treat as non-op call if (op_node->name == "relax.call_pure_packed") { return HandleMysteryCall(call_node, bound_var, true); - } else if (op_node->name == "relax.split") { - // split: Returns a tuple, treat as allocation - - // tuple is freshly allocated, but also add components to the tuple map - int tup_idx = get_fresh_idx(); - ret.insert(tup_idx); - // the LHS (the bound var) will definitely have a tuple struct info - InsertFreshTuple(tup_idx, GetStructInfoAs(bound_var)); } else if (op_node->name == "relax.call_tir") { // call_tir: can potentially return a tuple if (auto* tuple_struct_info = call_node->sinfo_args[0].as()) { @@ -282,8 +274,17 @@ class AliasAnalyzer { ret.insert(get_fresh_idx()); } } else { - // We are assuming most op calls return a single fresh allocation. + // We are assuming most op calls return fresh values. // We may have to track more exceptions + + // If the returned value is a tuple, we'll assume it's a fresh tuple + // (there may be exceptions to this too) + if (auto* tup_info = GetStructInfoAs(bound_var)) { + int tup_idx = get_fresh_idx(); + ret.insert(tup_idx); + InsertFreshTuple(tup_idx, tup_info); + return ret; + } ret.insert(get_fresh_idx()); } } else { From 3107ce47dd5a93dfa082902d841fa27610db9cd6 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 13 Jan 2024 21:50:01 -0500 Subject: [PATCH 43/55] Add full structural equality check in large test case --- tests/python/relax/test_dataflow_inplace.py | 88 ++++++++++++++++++--- 1 file changed, 77 insertions(+), 11 deletions(-) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index c0459e06ac53..cbd3a05bcc73 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -437,19 +437,85 @@ def main( R.output(m) return m + @I.ir_module + class Expected: + @T.prim_func(private=True) + def add_inplace( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(1), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[T.int64(0), v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[T.int64(0), v_ax1] + + @T.prim_func(private=True) + def multiply_inplace( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(1), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[T.int64(0), v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[T.int64(0), v_ax1] + + @T.prim_func(private=True) + def subtract_inplace( + A: T.Buffer((T.int64(1), T.int64(3)), "float32"), + B: T.Buffer((T.int64(1), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(1), T.int64(3)): + with T.block("T_subtract"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) + T.writes(B[v_ax0, v_ax1]) + B[v_ax0, v_ax1] = A[v_ax0, v_ax1] - B[v_ax0, v_ax1] + + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Expected + with R.dataflow(): + z: R.Tensor((2, 3), dtype="float32") = R.add(x, y) + a: R.Tensor((2, 3), dtype="float32") = R.call_tir_inplace( + cls.add_inplace, + (z, y), + inplace_indices=[0], + out_sinfo=[R.Tensor((2, 3), dtype="float32"),], + ) + q: R.Tensor((2, 3), dtype="float32") = R.call_tir_inplace( + cls.multiply_inplace, + (a, y), + inplace_indices=[0], + out_sinfo=[R.Tensor((2, 3), dtype="float32"),], + ) + r: R.Tensor((1, 3), dtype="float32") = R.subtract(y, y) + s: R.Tensor((1, 3), dtype="float32") = R.call_tir_inplace( + cls.subtract_inplace, + (r, r), + inplace_indices=[1], + out_sinfo=[R.Tensor((1, 3), dtype="float32"),], + ) + m: R.Tensor((2, 3), dtype="float32") = R.call_tir_inplace( + cls.multiply_inplace, + (q, s), + inplace_indices=[0], + out_sinfo=[R.Tensor((2, 3), dtype="float32"),], + ) + R.output(m) + return m + transform_pass = DataflowUseInplaceCalls() new_mod = transform_pass(EndToEndTest) - - # check that all operations are done in-place - assert new_mod["add_inplace"] - assert new_mod["subtract_inplace"] - assert new_mod["multiply_inplace"] - expected_ops = ["add_inplace", "multiply_inplace", "subtract_inplace", "multiply_inplace"] - inplace_binding_idxs = [1, 2, 4, 5] - for i, idx in enumerate(inplace_binding_idxs): - binding = new_mod["main"].body.blocks[0].bindings[idx] - assert binding.value.op.name == "relax.call_tir_inplace" - assert binding.value.args[0].name_hint == expected_ops[i] + tvm.ir.assert_structural_equal(new_mod, Expected) x = tvm.nd.array(np.random.rand(2, 3).astype("float32")) y = tvm.nd.array(np.random.rand(1, 3).astype("float32")) From 585a748230b67a94975f0dbd6619bc2598681f9e Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 13 Jan 2024 22:24:52 -0500 Subject: [PATCH 44/55] Fix parser roundtripping bug with call_tir_inplace --- python/tvm/relax/__init__.py | 8 ++++- src/script/printer/relax/call.cc | 19 +++++++++- tests/python/relax/test_tvmscript_parser.py | 36 +++++++++++++++++++ .../relax/test_tvmscript_printer_relax.py | 25 +++++++++++++ 4 files changed, 86 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 5bc0d6c56eb4..23cfaf293560 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -63,7 +63,13 @@ from .exec_builder import ExecBuilder # Operator -from .op.base import call_tir, call_pure_packed, call_dps_packed, call_tir_with_grad +from .op.base import ( + call_tir, + call_tir_inplace, + call_pure_packed, + call_dps_packed, + call_tir_with_grad, +) # BlockBuilder from .block_builder import BlockBuilder diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 785dc6d96320..ef9438350ce0 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -97,11 +97,13 @@ ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifi Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { static const Op& call_tir_op = Op::Get("relax.call_tir"); + static const Op& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); static const Op& call_tir_with_grad_op = Op::Get("relax.call_tir_with_grad"); static const Op& call_tir_local_view = Op::Get("relax.dist.call_tir_local_view"); if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op) && - !n->op.same_as(call_tir_with_grad_op) && !n->op.same_as(call_tir_local_view)) { + !n->op.same_as(call_tir_with_grad_op) && !n->op.same_as(call_tir_local_view) && + !n->op.same_as(call_tir_inplace_op)) { return NullOpt; } ICHECK(n->args.size() == 2 || n->args.size() == 3); @@ -135,6 +137,19 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& kwargs_values.push_back(d->AsDoc(o_sinfo, o_sinfo_p)); } + // for call_tir_inplace, we also need to include the inplace args + if (n->op.same_as(call_tir_inplace_op)) { + kwargs_keys.push_back("inplace_indices"); + Array index_fields; + if (auto* call_tir_inplace_attrs = n->attrs.as()) { + for (auto inplace_index : call_tir_inplace_attrs->inplace_indices) { + index_fields.push_back( + LiteralDoc::Int(inplace_index.IntValue(), n_p->Attr("attrs")->Attr("inplace_indices"))); + } + } + kwargs_values.push_back(ListDoc(index_fields)); + } + // start of specially handling call_tir_with_grad if (const auto* call_tir_with_grad_attrs = n->attrs.as()) { kwargs_keys.push_back("te_grad_name"); @@ -163,6 +178,8 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& return Relax(d, "dist.call_tir_local_view")->Call(args, kwargs_keys, kwargs_values); } else if (is_dtensor) { return Relax(d, "dist.call_tir")->Call(args, kwargs_keys, kwargs_values); + } else if (n->op.same_as(call_tir_inplace_op)) { + return Relax(d, "call_tir_inplace")->Call(args, kwargs_keys, kwargs_values); } else { return Relax(d, "call_tir")->Call(args, kwargs_keys, kwargs_values); } diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index b45c3c6e4a93..182d981b98e3 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -986,6 +986,42 @@ def main(v0: R.Tensor([54, 96], "float32")): _check(Module) +def test_call_tir_inplace(): + @tvm.script.ir_module + class Module: + @T.prim_func + def copy( + A: T.Buffer((2, 3), "int32"), + B: T.Buffer((2, 3), "int32"), + out1: T.Buffer((2, 3), "int32"), + ): + # copies the contents of B into A and out1 + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_zeros"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(B[ax0, ax1]) + T.writes(A[ax0, ax1], out1[ax0, ax1]) + A[ax0, ax1] = B[ax0, ax1] + out1[ax0, ax1] = B[ax0, ax1] + + @R.function + def main( + x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") + ) -> R.Tuple( + R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32") + ): + res = R.call_tir_inplace( + Module.copy, + (x, y), + [0, -1], + [R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], + ) + return res + + _check(Module) + + def test_local_function(): @R.function def main( diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index dc3334f216c0..530e45e61074 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -399,6 +399,31 @@ def test_call_tir_with_grad(): ) +def test_call_tir_inplace(): + x = relax.Var("x", R.Tensor((32, 32), dtype="int32")) + y = relax.Var("y", R.Tensor((32, 32), dtype="int32")) + t = tir.Var("t", dtype="int64") + call = relax.call_tir_inplace( + relax.GlobalVar("tir_func"), + ( + x, + y, + ), + inplace_indices=[-1, 0], + out_sinfo=[R.Tensor((32, 32), dtype="int32"), R.Tensor((32, 32), dtype="int32")], + tir_vars=[t], + ) + _assert_print( + call, + """ +x: R.Tensor((32, 32), dtype="int32") +y: R.Tensor((32, 32), dtype="int32") +t = T.int64() +R.call_tir_inplace(tir_func, (x, y), out_sinfo=[R.Tensor((32, 32), dtype="int32"), R.Tensor((32, 32), dtype="int32")], inplace_indices=[-1, 0], tir_vars=R.shape([t])) + """, + ) + + def test_seq_expr(): x = tir.Var("x", "int64") a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) From c5d2871c067ba7bdea4b27193eb2e268bd228796 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 13 Jan 2024 22:31:33 -0500 Subject: [PATCH 45/55] Refactor tests to ensure maps are nonempty --- tests/python/relax/test_dataflow_inplace.py | 24 ++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index cbd3a05bcc73..ffafff4e9d63 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -57,8 +57,8 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): "p": (3, 5), "n": (4, 5), } - for var, live_range in live_ranges.items(): - assert live_range == expected_ranges[var.name_hint] + actual_ranges = {var.name_hint: live_range for var, live_range in live_ranges.items()} + assert actual_ranges == expected_ranges def test_alias_analysis_basic(): @@ -124,8 +124,8 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): "n": {3}, } - for var, alias_set in alias_sets.items(): - assert alias_set == expected[var.name_hint] + actual_alias_sets = {var.name_hint : alias_set for var, alias_set in alias_sets.items()} + assert expected == actual_alias_sets assert 2 in tuple_map assert tuple_map[2] == [{0}, {1}] @@ -157,8 +157,8 @@ def main(x: R.Tensor((60,), "int32")) -> R.Tensor((15,), "int32"): "n": {3}, } - for var, alias_set in alias_sets.items(): - assert alias_set == expected[var.name_hint] + actual_alias_sets = {var.name_hint : alias_set for var, alias_set in alias_sets.items()} + assert expected == actual_alias_sets assert len(tuple_map) == 1 assert 1 in tuple_map assert tuple_map[1] == [{2}, {3}, {4}, {5}] @@ -231,8 +231,8 @@ def main(x: R.Tensor((10, 10), "int32")) -> R.Tensor((10, 10), "int32"): "v": {4}, } - for var, alias_set in alias_sets.items(): - assert alias_set == expected[var.name_hint] + actual_alias_sets = {var.name_hint : alias_set for var, alias_set in alias_sets.items()} + assert expected == actual_alias_sets assert len(tuple_map) == 1 assert 2 in tuple_map assert tuple_map[2] == [{3}, {4}] @@ -277,8 +277,8 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): # (in principle, we can use type information to narrow down the aliasing) } - for var, alias_set in alias_sets.items(): - assert alias_set == expected[var.name_hint] + actual_alias_sets = {var.name_hint : alias_set for var, alias_set in alias_sets.items()} + assert expected == actual_alias_sets assert len(tuple_map) == 2 assert 5 in tuple_map assert tuple_map[5] == [{3}, {4}] @@ -313,8 +313,8 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): "c": {-1}, } - for var, alias_set in alias_sets.items(): - assert alias_set == expected[var.name_hint] + actual_alias_sets = {var.name_hint : alias_set for var, alias_set in alias_sets.items()} + assert expected == actual_alias_sets assert len(tuple_map) == 1 assert 2 in tuple_map assert tuple_map[2] == [{-1}, {1}] From 9dc2e525cacc2a66b3f016a542dbf7f87b285f68 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 13 Jan 2024 22:33:29 -0500 Subject: [PATCH 46/55] Use .empty() where it's more reasonable --- src/relax/transform/dataflow_inplace.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 74c9cc4c5b8a..ca0c6f3e55b4 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -610,7 +610,7 @@ std::pair>, std::vector>> FindInpl } } } - if (!candidates.size()) { + if (!candidates.empty()) { continue; } @@ -629,7 +629,7 @@ std::pair>, std::vector>> FindInpl } // if we have a candidate, then this can be made in-place. Report the appropriate candidates - if (!candidates.size()) { + if (!candidates.empty()) { continue; } From b9d01b7313f787614329278cf9e412e23ca27db7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 15 Jan 2024 19:09:47 -0500 Subject: [PATCH 47/55] linting changes --- python/tvm/relax/testing/transform.py | 16 ++++--------- tests/python/relax/test_dataflow_inplace.py | 26 ++++++++++++++------- tests/python/relax/test_tvmscript_parser.py | 2 +- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 7e3e124364f7..6357c118f71e 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -138,9 +138,7 @@ def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int Inner function for the dataflow inplace transformation exposed for testing. """ if "PYTEST_CURRENT_TEST" not in os.environ: - logging.warning( - "The function dataflow_liveness_analysis is exposed for testing only." - ) + logging.warning("The function dataflow_liveness_analysis is exposed for testing only.") live_ranges = tvm.get_global_func("relax.testing.transform.DataflowLivenessAnalysis")( block @@ -158,9 +156,7 @@ def dataflow_alias_analysis( Inner function for the dataflow inplace transformation exposed for testing. """ if "PYTEST_CURRENT_TEST" not in os.environ: - logging.warning( - "The function dataflow_alias_analysis is exposed for testing only." - ) + logging.warning("The function dataflow_alias_analysis is exposed for testing only.") alias_sets, tuple_map = tvm.get_global_func("relax.testing.transform.DataflowAliasAnalysis")( block, @@ -182,9 +178,7 @@ def dataflow_inplace_analysis( Inner function for the dataflow inplace transformation exposed for testing. """ if "PYTEST_CURRENT_TEST" not in os.environ: - logging.warning( - "The function dataflow_inplace_analysis is exposed for testing only." - ) + logging.warning("The function dataflow_inplace_analysis is exposed for testing only.") index_lists = tvm.get_global_func("relax.testing.transform.DataflowInplaceAnalysis")( block, inputs, @@ -199,9 +193,7 @@ def dataflow_single_inplace_call( Inner function for the dataflow inplace transformation exposed for testing. """ if "PYTEST_CURRENT_TEST" not in os.environ: - logging.warning( - "The function dataflow_single_inplace_call is exposed for testing only." - ) + logging.warning("The function dataflow_single_inplace_call is exposed for testing only.") ret = tvm.get_global_func("relax.testing.transform.SingleInplaceCall")( mod, diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index ffafff4e9d63..bb20d97b23ee 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -124,7 +124,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): "n": {3}, } - actual_alias_sets = {var.name_hint : alias_set for var, alias_set in alias_sets.items()} + actual_alias_sets = {var.name_hint: alias_set for var, alias_set in alias_sets.items()} assert expected == actual_alias_sets assert 2 in tuple_map assert tuple_map[2] == [{0}, {1}] @@ -157,7 +157,7 @@ def main(x: R.Tensor((60,), "int32")) -> R.Tensor((15,), "int32"): "n": {3}, } - actual_alias_sets = {var.name_hint : alias_set for var, alias_set in alias_sets.items()} + actual_alias_sets = {var.name_hint: alias_set for var, alias_set in alias_sets.items()} assert expected == actual_alias_sets assert len(tuple_map) == 1 assert 1 in tuple_map @@ -231,7 +231,7 @@ def main(x: R.Tensor((10, 10), "int32")) -> R.Tensor((10, 10), "int32"): "v": {4}, } - actual_alias_sets = {var.name_hint : alias_set for var, alias_set in alias_sets.items()} + actual_alias_sets = {var.name_hint: alias_set for var, alias_set in alias_sets.items()} assert expected == actual_alias_sets assert len(tuple_map) == 1 assert 2 in tuple_map @@ -277,7 +277,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): # (in principle, we can use type information to narrow down the aliasing) } - actual_alias_sets = {var.name_hint : alias_set for var, alias_set in alias_sets.items()} + actual_alias_sets = {var.name_hint: alias_set for var, alias_set in alias_sets.items()} assert expected == actual_alias_sets assert len(tuple_map) == 2 assert 5 in tuple_map @@ -313,7 +313,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): "c": {-1}, } - actual_alias_sets = {var.name_hint : alias_set for var, alias_set in alias_sets.items()} + actual_alias_sets = {var.name_hint: alias_set for var, alias_set in alias_sets.items()} assert expected == actual_alias_sets assert len(tuple_map) == 1 assert 2 in tuple_map @@ -489,26 +489,34 @@ def main( cls.add_inplace, (z, y), inplace_indices=[0], - out_sinfo=[R.Tensor((2, 3), dtype="float32"),], + out_sinfo=[ + R.Tensor((2, 3), dtype="float32"), + ], ) q: R.Tensor((2, 3), dtype="float32") = R.call_tir_inplace( cls.multiply_inplace, (a, y), inplace_indices=[0], - out_sinfo=[R.Tensor((2, 3), dtype="float32"),], + out_sinfo=[ + R.Tensor((2, 3), dtype="float32"), + ], ) r: R.Tensor((1, 3), dtype="float32") = R.subtract(y, y) s: R.Tensor((1, 3), dtype="float32") = R.call_tir_inplace( cls.subtract_inplace, (r, r), inplace_indices=[1], - out_sinfo=[R.Tensor((1, 3), dtype="float32"),], + out_sinfo=[ + R.Tensor((1, 3), dtype="float32"), + ], ) m: R.Tensor((2, 3), dtype="float32") = R.call_tir_inplace( cls.multiply_inplace, (q, s), inplace_indices=[0], - out_sinfo=[R.Tensor((2, 3), dtype="float32"),], + out_sinfo=[ + R.Tensor((2, 3), dtype="float32"), + ], ) R.output(m) return m diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 182d981b98e3..f317d04f59ae 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1018,7 +1018,7 @@ def main( [R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], ) return res - + _check(Module) From 8dd70df73ca572a03187a7cb8602d9041d384dcc Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 15 Jan 2024 20:32:05 -0500 Subject: [PATCH 48/55] Flipped the check by accident --- src/relax/transform/dataflow_inplace.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index ca0c6f3e55b4..f194e8ccac1d 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -610,7 +610,7 @@ std::pair>, std::vector>> FindInpl } } } - if (!candidates.empty()) { + if (candidates.empty()) { continue; } @@ -629,7 +629,7 @@ std::pair>, std::vector>> FindInpl } // if we have a candidate, then this can be made in-place. Report the appropriate candidates - if (!candidates.empty()) { + if (candidates.empty()) { continue; } From 285c7d6db969e7b841929a8eee6df8c30e232712 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 15 Jan 2024 20:39:34 -0500 Subject: [PATCH 49/55] Remove debug print --- tests/python/relax/test_dataflow_inplace.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index bb20d97b23ee..097325c6d56f 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -45,7 +45,6 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): return n block = BasicLiveness["main"].body.blocks[0] - print(block.bindings[-2].var, type(block.bindings[-2].var)) live_ranges = dataflow_liveness_analysis(block) expected_ranges = { # x is live past the binding block From 81b1fcbea1a0a7ddcf625600d2967ef1b6c090fc Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 15 Jan 2024 20:40:21 -0500 Subject: [PATCH 50/55] Factor out data structure for representing matches and match opportunities --- python/tvm/relax/testing/transform.py | 28 ++++++- src/relax/transform/dataflow_inplace.cc | 90 +++++++++++++-------- tests/python/relax/test_dataflow_inplace.py | 8 +- 3 files changed, 86 insertions(+), 40 deletions(-) diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 6357c118f71e..0d1e4a1477e9 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -28,6 +28,7 @@ from tvm.relax import PyExprMutator from tvm.relax.expr import Call, DataflowBlock, Var from tvm.relay.backend.te_compiler import select_implementation +from tvm.runtime.object import Object from tvm.target import Target @@ -171,9 +172,28 @@ def dataflow_alias_analysis( return res_alias_sets, res_tuple_map # type: ignore +@tvm._ffi.register_object("relax.transform.InplaceOpportunity") +class InplaceOpportunity(Object): + """ + Represents an opportunity to make a binding in-place. Exposed only for testing; + the constructor is not exposed. + + Parameters: + ----------- + binding_idx: int + Index of the binding within its block + + arg_idxs: List[int] + Indices of arguments that are eligible to be used as in-place targets. + """ + + def __init__(_): + raise NotImplementedError("Constructor for InplaceOpportunity not exposed!") + + def dataflow_inplace_analysis( block: DataflowBlock, inputs: List[Var] -) -> Tuple[List[int], List[int]]: +) -> Tuple[List[Tuple[int, Set[int]]], List[Tuple[int, Set[int]]]]: """ Inner function for the dataflow inplace transformation exposed for testing. """ @@ -183,7 +203,11 @@ def dataflow_inplace_analysis( block, inputs, ) # type: ignore - return tuple(map(list, index_lists)) # type: ignore + + def convert(opp_list): + return list(map(lambda opp: (int(opp.binding_idx), set(map(int, opp.arg_idxs))), opp_list)) + + return (convert(index_lists[0]), convert(index_lists[1])) # type: ignore def dataflow_single_inplace_call( diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index f194e8ccac1d..04c94cc48dbb 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -276,7 +276,7 @@ class AliasAnalyzer { } else { // We are assuming most op calls return fresh values. // We may have to track more exceptions - + // If the returned value is a tuple, we'll assume it's a fresh tuple // (there may be exceptions to this too) if (auto* tup_info = GetStructInfoAs(bound_var)) { @@ -528,6 +528,38 @@ static std::unordered_set SUPPORTED_OPS = {"relax.add", "relax "relax.nn.silu", "relax.nn.relu"}; bool OpSupportsInplace(const Op& op) { return SUPPORTED_OPS.count(op->name); } +/*! \brief Corresponds to a binding where at least one argument meets the conditions to be + * made in-place. Contains the binding index and indices of the applicable arguments + */ +class InplaceOpportunityNode : public Object { + public: + // need to use Array for the benefit of the FFI + Integer binding_idx; + Array arg_idxs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("binding_idx", &binding_idx); + v->Visit("arg_idxs", &arg_idxs); + } + + static constexpr const char* _type_key = "relax.transform.InplaceOpportunity"; + TVM_DECLARE_BASE_OBJECT_INFO(InplaceOpportunityNode, Object); +}; + +TVM_REGISTER_NODE_TYPE(InplaceOpportunityNode); + +class InplaceOpportunity : public ObjectRef { + public: + TVM_DLL InplaceOpportunity(const Integer& binding_idx, const Array& arg_idxs) { + auto node = make_object(); + node->binding_idx = binding_idx; + node->arg_idxs = arg_idxs; + data_ = std::move(node); + } + + TVM_DEFINE_OBJECT_REF_METHODS(InplaceOpportunity, ObjectRef, InplaceOpportunityNode); +}; + // Check for in-place eligibility: // 1. see if there's an arg big enough to hold the result // 2. see if the arg is live past the call @@ -541,16 +573,16 @@ bool OpSupportsInplace(const Op& op) { return SUPPORTED_OPS.count(op->name); } // For both lists, each element is a list of ints of the following format: // The first element is the index of the *binding* in the block. // All remaining elements are the indices of *eligible arguments* in that call. -std::pair>, std::vector>> FindInplaceOpportunities( - const DataflowBlock& block, const Array& inputs) { +std::pair, std::vector> +FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs) { auto live_ranges = AnalyzeLiveness(block); AliasAnalyzer analyzer; auto alias_info = analyzer.Analyze(block, inputs); auto alias_sets = alias_info.first; auto tuple_map = alias_info.second; - std::vector> size_match_list; - std::vector> exact_match_list; + std::vector size_match_list; + std::vector exact_match_list; // sort the live ranges by starting index std::vector live_order; @@ -634,20 +666,24 @@ std::pair>, std::vector>> FindInpl } // produce a list of candidates for this index - std::vector size_match_indices = {static_cast(i)}; - size_match_indices.insert(size_match_indices.end(), candidates.begin(), candidates.end()); - size_match_list.push_back(size_match_indices); + Array size_candidate_list; + for (auto candidate : candidates) { + size_candidate_list.push_back(Integer(candidate)); + } + size_match_list.push_back(InplaceOpportunity(Integer(i), size_candidate_list)); // also gather up the exact match candidates if there are any - std::vector exact_match_indices = {static_cast(i)}; + Array exact_candidate_list; for (auto candidate : candidates) { - if (exact_match_candidates.count(candidate)) { - exact_match_indices.push_back(candidate); + if (!exact_match_candidates.count(candidate)) { + continue; } + exact_candidate_list.push_back(Integer(candidate)); } - if (exact_match_indices.size() > 1) { - exact_match_list.push_back(exact_match_indices); + if (exact_candidate_list.empty()) { + continue; } + exact_match_list.push_back(InplaceOpportunity(Integer(i), exact_candidate_list)); } } @@ -800,7 +836,7 @@ class ModuleInplaceTransformer : public ExprMutator { int current_match_index = 0; for (size_t i = 0; i < block->bindings.size(); i++) { if (current_match_index >= static_cast(exact_matches.size()) || - exact_matches[current_match_index][0] != static_cast(i)) { + exact_matches[current_match_index]->binding_idx.IntValue() != static_cast(i)) { new_bindings.push_back(block->bindings[i]); continue; } @@ -808,7 +844,8 @@ class ModuleInplaceTransformer : public ExprMutator { auto target_binding = block->bindings[i]; auto target_call = Downcast(GetBoundValue(target_binding)); // can just pick the first index arbitrarily (only using one output for now too) - auto new_call = CreateInplaceCall(target_call, {exact_matches[current_match_index][1]}); + auto new_call = + CreateInplaceCall(target_call, {exact_matches[current_match_index]->arg_idxs[0]}); // now replace the binding appropriately if (auto* var_binding_node = target_binding.as()) { auto var_binding = GetRef(var_binding_node); @@ -974,26 +1011,11 @@ tvm::transform::Pass DataflowUseInplaceCalls() { 0, "DataflowInsertInPlaceCalls", {}, false); } -Array>> DataflowInplaceAnalysis(const DataflowBlock& block, - const Array& inputs) { +Array> DataflowInplaceAnalysis(const DataflowBlock& block, + const Array& inputs) { auto index_lists = relax::FindInplaceOpportunities(block, inputs); - Array> size_match_array; - for (auto indices : index_lists.first) { - Array index_array; - for (auto index : indices) { - index_array.push_back(Integer(index)); - } - size_match_array.push_back(index_array); - } - Array> exact_match_array; - for (auto indices : index_lists.second) { - Array index_array; - for (auto index : indices) { - index_array.push_back(Integer(index)); - } - exact_match_array.push_back(index_array); - } - return {size_match_array, exact_match_array}; + return {Array(index_lists.first.begin(), index_lists.first.end()), + Array(index_lists.second.begin(), index_lists.second.end())}; } // these are exposed only for testing diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index 097325c6d56f..7bbf69378fed 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -341,14 +341,14 @@ def main( # order does not matter for the listing of candidates, so we have to implement as sets def assert_candidate_list( - actual: List[List[int]], expected: List[Tuple[int, Set[int]]] + actual: List[Tuple[int, Set[int]]], expected: List[Tuple[int, Set[int]]] ) -> None: assert len(actual) == len(expected) for i in range(len(actual)): assert actual[i][0] == expected[i][0] - assert len(expected[i][1]) == len(actual[i]) - 1 - for j in range(len(expected[i][1])): - assert actual[i][j + 1] in expected[i][1] + assert len(expected[i][1]) == len(actual[i][1]) + for idx in actual[i][1]: + assert idx in expected[i][1] assert_candidate_list(size_match, [(1, {0, 1}), (4, {1}), (5, {0, 1})]) # TODO(@slyubomirsky): I couldn't think of an easy example where sizes don't match, From cc48b257cc1e6c1ef176e81c04397e5cee68c5c2 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 15 Jan 2024 21:34:25 -0500 Subject: [PATCH 51/55] Style fix --- python/tvm/relax/testing/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 0d1e4a1477e9..85560f97fdd6 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -187,7 +187,7 @@ class InplaceOpportunity(Object): Indices of arguments that are eligible to be used as in-place targets. """ - def __init__(_): + def __init__(self, _binding_idx, _arg_idxs): raise NotImplementedError("Constructor for InplaceOpportunity not exposed!") From a3683e79976930c28cce944568403f35496eb075 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 15 Jan 2024 22:30:38 -0500 Subject: [PATCH 52/55] Use the analyzer to handle dynamic cases too --- python/tvm/relax/testing/transform.py | 3 +- src/relax/transform/dataflow_inplace.cc | 44 ++++---- tests/python/relax/test_dataflow_inplace.py | 107 +++++++++++++++++++- 3 files changed, 128 insertions(+), 26 deletions(-) diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 85560f97fdd6..d9d90799f55d 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -192,7 +192,7 @@ def __init__(self, _binding_idx, _arg_idxs): def dataflow_inplace_analysis( - block: DataflowBlock, inputs: List[Var] + block: DataflowBlock, inputs: List[Var], mod: IRModule ) -> Tuple[List[Tuple[int, Set[int]]], List[Tuple[int, Set[int]]]]: """ Inner function for the dataflow inplace transformation exposed for testing. @@ -202,6 +202,7 @@ def dataflow_inplace_analysis( index_lists = tvm.get_global_func("relax.testing.transform.DataflowInplaceAnalysis")( block, inputs, + mod ) # type: ignore def convert(opp_list): diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 04c94cc48dbb..6e44ac5c958b 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -322,13 +322,8 @@ std::unordered_set GatherCandidateSin return {}; } // don't consider cases where we don't know the shape at compile time - // TODO(@slyubomirsky): variables might be okay if we use the arithmetic analyzer - if (auto* shape_node = tensor_info->shape.as()) { - for (auto dim : shape_node->values) { - if (!dim.as()) { - return {}; - } - } + // (we will use the analyzer to do best-effort analysis where there are vars) + if (tensor_info->shape.as()) { return {GetRef(tensor_info)}; } else { return {}; @@ -356,7 +351,8 @@ std::unordered_set GatherCandidateSin // if the shapes match _exactly_. Performs this check recursively and ensures the // stated condition is true for all tensor members of the struct info (return false // if a single pair of corresponding tensors does not meet the condition). -std::pair SizeMatches(const StructInfo& target_info, const StructInfo& arg_info) { +std::pair SizeMatches(const StructInfo& target_info, const StructInfo& arg_info, + const BlockBuilder& ctx) { if (target_info.as() && arg_info.as()) { auto target_tensor = Downcast(target_info); auto arg_tensor = Downcast(arg_info); @@ -369,18 +365,13 @@ std::pair SizeMatches(const StructInfo& target_info, const StructInf auto arg_shape = Downcast(arg_tensor->shape); PrimExpr target_size = NumElements(target_shape); PrimExpr arg_size = NumElements(arg_shape); - if (!target_size.as() || !arg_size.as() || - Downcast(target_size)->value < Downcast(arg_size)->value) { + if (!ctx->GetAnalyzer()->CanProve(arg_size >= target_size)) { return {false, false}; } // exact match: number of dims and each dim matches if (target_shape->values.size() == arg_shape->values.size()) { for (size_t i = 0; i < target_shape->values.size(); i++) { - if (!arg_shape->values[i].as()) { - return {false, false}; - } - if (Downcast(target_shape->values[i])->value != - Downcast(arg_shape->values[i])->value) { + if (!ctx->GetAnalyzer()->CanProveEqual(target_shape->values[i], arg_shape->values[i])) { return {true, false}; } } @@ -407,7 +398,7 @@ std::pair SizeMatches(const StructInfo& target_info, const StructInf continue; } auto [field_size_match, field_exact_match] = - SizeMatches(target_tup->fields[i], arg_tup->fields[i]); + SizeMatches(target_tup->fields[i], arg_tup->fields[i], ctx); if (!field_size_match) { return {false, false}; } @@ -574,7 +565,8 @@ class InplaceOpportunity : public ObjectRef { // The first element is the index of the *binding* in the block. // All remaining elements are the indices of *eligible arguments* in that call. std::pair, std::vector> -FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs) { +FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, + const BlockBuilder& ctx) { auto live_ranges = AnalyzeLiveness(block); AliasAnalyzer analyzer; auto alias_info = analyzer.Analyze(block, inputs); @@ -633,10 +625,10 @@ FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs) { for (size_t j = 0; j < call_node->args.size(); j++) { auto arg = call_node->args[j]; for (auto target : target_sinfo) { - std::pair match = SizeMatches(target, GetStructInfo(arg)); - if (match.first) { + auto [matches_size, matches_exactly] = SizeMatches(target, GetStructInfo(arg), ctx); + if (matches_size) { candidates.insert(static_cast(j)); - if (match.second) { + if (matches_exactly) { exact_match_candidates.insert(static_cast(j)); } } @@ -829,7 +821,7 @@ class ModuleInplaceTransformer : public ExprMutator { // For now, only handle exact match cases. // Note: Not passing any input values for now, as we can't make any assumptions // about them. - auto matches_found = FindInplaceOpportunities(block, {}); + auto matches_found = FindInplaceOpportunities(block, {}, builder_); auto exact_matches = matches_found.second; Array new_bindings; @@ -1003,17 +995,21 @@ Array DataflowAliasAnalysis(const DataflowBlock& block, Array in // this would be preferable to do as a dataflow block pass, // but the transformation adds new PrimFuncs, so it affects the module tvm::transform::Pass DataflowUseInplaceCalls() { - return tvm::transform::CreateModulePass( + auto inplace_pass = tvm::transform::CreateModulePass( [](const IRModule& mod, const PassContext& ctx) -> IRModule { ModuleInplaceTransformer transformer(mod); return transformer.Transform(); }, 0, "DataflowInsertInPlaceCalls", {}, false); + // odd quirk: if Normalize is not explicitly called, then the function + // StructInfo will not be properly updated + return tvm::transform::Sequential({inplace_pass, Normalize()}); } Array> DataflowInplaceAnalysis(const DataflowBlock& block, - const Array& inputs) { - auto index_lists = relax::FindInplaceOpportunities(block, inputs); + const Array& inputs, + const IRModule& mod) { + auto index_lists = relax::FindInplaceOpportunities(block, inputs, BlockBuilder::Create(mod)); return {Array(index_lists.first.begin(), index_lists.first.end()), Array(index_lists.second.begin(), index_lists.second.end())}; } diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index 7bbf69378fed..8d5eb07c7858 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -337,7 +337,9 @@ def main( return ret block = InplaceBasic["main"].body.blocks[0] - size_match, exact_match = dataflow_inplace_analysis(block, InplaceBasic["main"].params) + size_match, exact_match = dataflow_inplace_analysis( + block, InplaceBasic["main"].params, InplaceBasic + ) # order does not matter for the listing of candidates, so we have to implement as sets def assert_candidate_list( @@ -535,5 +537,108 @@ def main( assert (expected == res.numpy()).all() +def test_dynamic(): + @I.ir_module + class DynamicTestCase: + @R.function + def main( + x: R.Tensor(("a", "b"), dtype="float32"), y: R.Tensor(("a", "b"), dtype="float32") + ) -> R.Tensor(("a", "b"), dtype="float32"): + with R.dataflow(): + z = R.add(x, y) + # Cannot be done in-place because x and y are arguments + a = R.add(z, y) # this one can be done in-place + s = R.subtract(a, a) # No broadcast. Can be done in-place + R.output(s) + return s + + # the result should be all zeroes + transform_pass = DataflowUseInplaceCalls() + new_mod = transform_pass(DynamicTestCase) + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def add_inplace(var_A: T.handle, var_B: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + a, b = T.int64(), T.int64() + A = T.match_buffer(var_A, (a, b)) + B = T.match_buffer(var_B, (a, b)) + for ax0, ax1 in T.grid(a, b): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1] + + @T.prim_func(private=True) + def subtract_inplace(var_A: T.handle, var_B: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + a, b = T.int64(), T.int64() + A = T.match_buffer(var_A, (a, b)) + B = T.match_buffer(var_B, (a, b)) + for ax0, ax1 in T.grid(a, b): + with T.block("T_subtract"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) + T.writes(B[v_ax0, v_ax1]) + B[v_ax0, v_ax1] = A[v_ax0, v_ax1] - B[v_ax0, v_ax1] + + @R.function + def main( + x: R.Tensor(("a", "b"), dtype="float32"), y: R.Tensor(("a", "b"), dtype="float32") + ) -> R.Tensor(("a", "b"), dtype="float32"): + a = T.int64() + b = T.int64() + cls = Expected + with R.dataflow(): + z = R.add(x, y) + a_1 = R.call_tir_inplace( + cls.add_inplace, + (z, y), + out_sinfo=R.Tensor((a, b), dtype="float32"), + inplace_indices=[0], + ) + s = R.call_tir_inplace( + cls.subtract_inplace, + (a_1, a_1), + out_sinfo=R.Tensor((a, b), dtype="float32"), + inplace_indices=[1], + ) + R.output(s) + return s + + tvm.ir.assert_structural_equal(new_mod, Expected, map_free_vars=True) + x = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + y = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + expected = np.zeros((2, 3), dtype="float32") + + target = tvm.target.Target("llvm") + ex = relax.build(new_mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"](x, y) + assert (expected == res.numpy()).all() + + +def test_dynamic_mismatch(): + # cannot statically prove the shapes to be equal so the module should be unchanged + @I.ir_module + class DynamicMistmatchTestCase: + @R.function + def main( + x: R.Tensor(("a", "b"), dtype="float32"), y: R.Tensor(("c", "d"), dtype="float32") + ): + with R.dataflow(): + z = R.add(x, y) + # Cannot be done in-place because x and y are arguments + a = R.add(z, y) # cannot conclude that shapes match + R.output(a) + return a + + transform_pass = DataflowUseInplaceCalls() + new_mod = transform_pass(DynamicMistmatchTestCase) + tvm.ir.assert_structural_equal(new_mod, DynamicMistmatchTestCase) + + if __name__ == "__main__": testing.main() From 463893fecfec1a4a4ba6c7ae5906d88162a5b502 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 15 Jan 2024 22:36:33 -0500 Subject: [PATCH 53/55] Whitespace --- python/tvm/relax/testing/transform.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index d9d90799f55d..42dbd37d2931 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -200,9 +200,7 @@ def dataflow_inplace_analysis( if "PYTEST_CURRENT_TEST" not in os.environ: logging.warning("The function dataflow_inplace_analysis is exposed for testing only.") index_lists = tvm.get_global_func("relax.testing.transform.DataflowInplaceAnalysis")( - block, - inputs, - mod + block, inputs, mod ) # type: ignore def convert(opp_list): From 559f197ff76e63021c75819874d596ff133369cc Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 16 Jan 2024 21:34:07 -0500 Subject: [PATCH 54/55] Use BlockBuilder APIs more to avoid re-normalizing --- src/relax/transform/dataflow_inplace.cc | 93 +++++++++++++------------ 1 file changed, 48 insertions(+), 45 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 6e44ac5c958b..72efb5881a7f 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -788,10 +788,7 @@ class ModuleInplaceTransformer : public ExprMutator { if (auto* func_node = kv.second.as()) { auto gv = kv.first; auto func_params = func_node->params; - auto function = GetRef(func_node); - auto* function_cow = function.CopyOnWrite(); - auto new_body = VisitExpr(function->body); - function_cow->body = new_body; + auto function = Downcast(VisitExpr(GetRef(func_node))); builder_->UpdateFunction(gv, function); } } @@ -804,11 +801,10 @@ class ModuleInplaceTransformer : public ExprMutator { return ret; } - // for handling inner functions Expr VisitExpr_(const FunctionNode* op) override { auto old_func_params = func_params; func_params = op->params; - auto ret = ExprMutator::VisitExpr(GetRef(op)); + auto ret = ExprMutator::VisitExpr_(op); func_params = old_func_params; return ret; } @@ -817,44 +813,51 @@ class ModuleInplaceTransformer : public ExprMutator { // and replace any valid calls in them BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { auto block = GetRef(op); + auto old_idxs = inplace_idxs; // For now, only handle exact match cases. // Note: Not passing any input values for now, as we can't make any assumptions // about them. auto matches_found = FindInplaceOpportunities(block, {}, builder_); - auto exact_matches = matches_found.second; - - Array new_bindings; - int current_match_index = 0; - for (size_t i = 0; i < block->bindings.size(); i++) { - if (current_match_index >= static_cast(exact_matches.size()) || - exact_matches[current_match_index]->binding_idx.IntValue() != static_cast(i)) { - new_bindings.push_back(block->bindings[i]); - continue; - } + Map> new_idxs; + for (auto match : matches_found.second) { + new_idxs.Set(block->bindings[match->binding_idx.IntValue()], match->arg_idxs); + } - auto target_binding = block->bindings[i]; - auto target_call = Downcast(GetBoundValue(target_binding)); - // can just pick the first index arbitrarily (only using one output for now too) - auto new_call = - CreateInplaceCall(target_call, {exact_matches[current_match_index]->arg_idxs[0]}); - // now replace the binding appropriately - if (auto* var_binding_node = target_binding.as()) { - auto var_binding = GetRef(var_binding_node); - auto* var_binding_cow = var_binding.CopyOnWrite(); - var_binding_cow->value = new_call; - new_bindings.push_back(var_binding); - } else if (auto* match_cast_node = target_binding.as()) { - auto match_cast = GetRef(match_cast_node); - auto* match_cast_cow = match_cast.CopyOnWrite(); - match_cast_cow->value = new_call; - new_bindings.push_back(match_cast); - } else { - CHECK(false) << "Invalid binding type"; - } - current_match_index++; + inplace_idxs = new_idxs; + auto ret = ExprMutator::VisitBindingBlock_(op); + inplace_idxs = old_idxs; + return ret; + } + + Expr ReplaceBoundCall(const Binding& binding) { + // can just pick the first index arbitrarily (only using one output for now too) + // now replace the binding appropriately + auto arg_idxs = inplace_idxs.at(binding); + auto target = Downcast(GetBoundValue(binding)); + auto new_call = CreateInplaceCall(target, {arg_idxs[0]}); + return builder_->Normalize(new_call); + } + + void VisitBinding_(const VarBindingNode* binding) override { + auto binding_ref = GetRef(binding); + if (!inplace_idxs.count(binding_ref)) { + ExprMutator::VisitBinding_(binding); + return; + } + Expr new_value = ReplaceBoundCall(binding_ref); + builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span)); + } + + void VisitBinding_(const MatchCastNode* binding) override { + auto binding_ref = GetRef(binding); + if (!inplace_idxs.count(binding_ref)) { + ExprMutator::VisitBinding_(binding); + return; } - return DataflowBlock(new_bindings, block->span); + Expr new_value = ReplaceBoundCall(binding_ref); + builder_->EmitNormalized( + MatchCast(binding->var, new_value, binding->struct_info, binding->span)); } // Given the call and indices of arguments that could be done in-place, @@ -947,10 +950,13 @@ class ModuleInplaceTransformer : public ExprMutator { private: const IRModule& mod_; - Array - legalizers_added; // Keep track of legalizers we add so we can clean up at the end. - Array func_params; // The current function's params will be treated as non-aliased - // (we are assuming good behavior on the user's part). + // Keep track of legalizers we add so we can clean up at the end. + Array legalizers_added; + // The current function's params will be treated as non-aliased + // (we are assuming good behavior on the user's part). + Array func_params; + // map of eligible bindings to indices of arguments that can be used as the in-place target + Map> inplace_idxs; }; namespace transform { @@ -995,15 +1001,12 @@ Array DataflowAliasAnalysis(const DataflowBlock& block, Array in // this would be preferable to do as a dataflow block pass, // but the transformation adds new PrimFuncs, so it affects the module tvm::transform::Pass DataflowUseInplaceCalls() { - auto inplace_pass = tvm::transform::CreateModulePass( + return tvm::transform::CreateModulePass( [](const IRModule& mod, const PassContext& ctx) -> IRModule { ModuleInplaceTransformer transformer(mod); return transformer.Transform(); }, 0, "DataflowInsertInPlaceCalls", {}, false); - // odd quirk: if Normalize is not explicitly called, then the function - // StructInfo will not be properly updated - return tvm::transform::Sequential({inplace_pass, Normalize()}); } Array> DataflowInplaceAnalysis(const DataflowBlock& block, From d9a7973b66ac6825304f67f66983d24a2febcc52 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 16 Jan 2024 22:10:18 -0500 Subject: [PATCH 55/55] Check for expired vars at start of loop so that the use of continue does not skip that step --- src/relax/transform/dataflow_inplace.cc | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 72efb5881a7f..755c5dbab433 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -600,6 +600,18 @@ FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, currently_live.insert(live_var); last_live++; } + // remove vars whose range has come to an end + // (keep a separate set to avoid changing the set while iterating on it) + std::unordered_set remove; + for (auto var : currently_live) { + auto live_range = live_ranges[var]; + if (live_range.second < static_cast(i)) { + remove.insert(var); + } + } + for (auto var : remove) { + currently_live.erase(var); + } // if we reach a binding check the conditions Binding b = block->bindings[i]; @@ -678,19 +690,6 @@ FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, exact_match_list.push_back(InplaceOpportunity(Integer(i), exact_candidate_list)); } } - - // remove vars whose range has come to an end - // (keep a separate set to avoid changing the sit while iterating on it) - std::unordered_set remove; - for (auto var : currently_live) { - auto live_range = live_ranges[var]; - if (live_range.second <= static_cast(i)) { - remove.insert(var); - } - } - for (auto var : remove) { - currently_live.erase(var); - } } return {size_match_list, exact_match_list};