From da74844b1852bd67565b648a618e08a2fcce677e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 2 Oct 2023 11:56:06 -0500 Subject: [PATCH 1/5] [Unity] Ensure one VM register for each relax binding Prior to this commit, if a relax variable was bound directly to another relax variable, through either a `VarBinding` or a `MatchCast` node, the two relax variables would share the same register in the VM. As a result, any upstream transform that deletes an object with `R.memory.kill_tensor` or `R.memory.kill_storage` had to be aware of this VM behavior, and had to output only one such instruction for each set of variables aliasing the same register. This commit updates the VM to produce one register for each aliased relax variable. The trivial bindings can be removed by applying `relax.transform.CanonicalizeBindings`, instead of being implicitly de-duplicated at the codegen level. This PR is a follow-up to https://github.com/apache/tvm/pull/15854, providing a better long-term solution, though one that may have knock-on effects that must also be resolved. In addition, this adds a usage example for the bug reported in https://github.com/apache/tvm/pull/15852, to avoid recurrence of similar issues. 
--- src/relax/backend/vm/codegen_vm.cc | 28 ++++++--- src/relax/transform/kill_after_last_use.cc | 37 +----------- .../python/relax/test_kill_after_last_use.py | 3 +- .../test_transform_lazy_transform_params.py | 59 ++++++++++++++++++- 4 files changed, 83 insertions(+), 44 deletions(-) diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index caee0a0c13d6..0a6657f95793 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -122,14 +122,28 @@ class CodeGenVM : public ExprFunctor { Instruction::Arg VisitExpr_(const SeqExprNode* op) final { for (auto block : op->blocks) { for (Binding binding : block->bindings) { - Instruction::Arg value; - if (auto* var_binding = binding.as()) { - value = this->VisitExpr(var_binding->value); - } else if (auto* match_cast = binding.as()) { - value = this->VisitExpr(match_cast->value); - } else { - LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey(); + Expr expr = [&binding]() { + if (auto* var_binding = binding.as()) { + return var_binding->value; + } else if (auto* match_cast = binding.as()) { + return match_cast->value; + } else { + LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey(); + } + }(); + + Instruction::Arg value = VisitExpr(expr); + if (expr.as()) { + // For a normalized relax module, there should be one + // register for each relax::Binding. This makes the Relax + // semantics of R.vm.kill_* operate the same as the Python + // "del" operator. These bindings may be removable by using + // relax.transform.CanonicalizeBindings earlier in lowering. 
+ RegName new_reg = NewRegister(); + builder_->EmitCall("vm.builtin.copy", {value}, new_reg); + value = Instruction::Arg::Register(new_reg); } + this->var_arg_map_.insert({binding->var, value}); } } diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc index 5b6a098ab6a6..fdb32356e67f 100644 --- a/src/relax/transform/kill_after_last_use.cc +++ b/src/relax/transform/kill_after_last_use.cc @@ -149,7 +149,7 @@ class CollectLastUsage : public ExprVisitor { // completes, last_usage_of_ contains the last usage point. If // this occurs in an output, then current_binding_ will be // nullptr. - last_usage_of_[UnwrapTrivialBindings(op)] = current_binding_; + last_usage_of_[op] = current_binding_; } void VisitBinding_(const VarBindingNode* binding, const CallNode* val) override { @@ -169,51 +169,18 @@ class CollectLastUsage : public ExprVisitor { << "but instead found " << val->args.size() << " arguments: " << val->args; auto killed_object = val->args[0].as(); ICHECK(killed_object) << "Internal error: non-normalized expression " << GetRef(val); - killed_objects_.insert(UnwrapTrivialBindings(killed_object)); + killed_objects_.insert(killed_object); } else { // Only recursively visit if it isn't one of the special cases. ExprVisitor::VisitBinding_(binding, val); } } - void VisitBinding_(const VarBindingNode* binding, const VarNode* val) override { - // Because the VM re-uses the same register for variable - // re-binding, we need to de-duplicate across trivial bindings in - // order to avoid calling `vm.kill_object` multiple times on the - // same register. In the future, this can be simplified by - // replacing the de-duplication in CodeGenVM with a call to - // CanonicalizeBindings. - trivial_bindings_.insert({binding->var.get(), UnwrapTrivialBindings(val)}); - - // Do not call ExprVisitor::VisitBinding_ here, as the trivial - // rebinding should not be treated as a point of use. 
- } - - void VisitBinding_(const MatchCastNode* binding) override { - if (auto rebound = binding->value.as()) { - // Because CodeGenVM treats MatchCast nodes identically to - // VarBinding nodes, we must also de-duplicate at this level. - trivial_bindings_.insert({binding->var.get(), UnwrapTrivialBindings(rebound)}); - } else { - ExprVisitor::VisitBinding_(binding); - } - } - void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) override { constant_tensors_.insert(binding->var.get()); } private: - const VarNode* UnwrapTrivialBindings(const VarNode* var) const { - while (true) { - if (auto it = trivial_bindings_.find(var); it != trivial_bindings_.end()) { - var = it->second; - } else { - return var; - } - } - } - // The current binding being visited, or nullptr if no binding is // being visited. const VarNode* current_binding_{nullptr}; diff --git a/tests/python/relax/test_kill_after_last_use.py b/tests/python/relax/test_kill_after_last_use.py index 1e4dc71877eb..41f4409f0b99 100644 --- a/tests/python/relax/test_kill_after_last_use.py +++ b/tests/python/relax/test_kill_after_last_use.py @@ -94,8 +94,9 @@ class Expected: def main(w: R.Tensor([16, 32], "float32")): x = R.add(w, R.const(1, "float32")) y = R.match_cast(x, R.Tensor([16, 32])) - z = R.add(y, R.const(1, "float32")) _ = R.memory.kill_tensor(x) + z = R.add(y, R.const(1, "float32")) + _ = R.memory.kill_tensor(y) return z After = KillAfterLastUse()(Before) diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 9ecebafae6df..6ce728ba95af 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -14,9 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import numpy as np + import tvm import tvm.testing -import pytest + +from tvm import relax from tvm.script import relax as R, tir as T from tvm.script import ir as I from tvm.relax.transform import LazyTransformParams @@ -319,5 +322,59 @@ def main_transform_params() -> R.Tuple: tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True) +def test_output(): + target = "llvm" + dev = tvm.device(target) + + @I.ir_module + class TransformModule: + @R.function + def transform_params( + params: R.Tuple( + R.Tensor((3, "ic", 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + ) + ) -> R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32") + ): + R.func_attr({"relax.force_pure": True}) + param0 = params[0] + param1 = params[1] + transformed0 = R.permute_dims(param0, [1, 0, 2, 3]) + transformed = (transformed0, param1) + return transformed + + mod = TransformModule + mod = relax.transform.LazyTransformParams()(mod) + mod = relax.transform.LegalizeOps()(mod) + built = relax.build(mod, target=target) + + params = [ + np.random.random(size=(3, 64, 3, 3)).astype("float32"), + np.random.random(size=(16, 16, 3, 3)).astype("float32"), + ] + transformed = {} + expected = [params[0].transpose(1, 0, 2, 3), params[1]] + + @tvm.register_func("get_item", override=True) + def get_item(i): + return tvm.nd.array(params[i], dev) + + @tvm.register_func("set_item", override=True) + def set_item(i, value): + assert i not in transformed, f"Set item called multiple times for index {i}" + transformed[i] = value.numpy() + + vm = relax.VirtualMachine(built, dev) + vm["transform_params"]() + + assert sorted(transformed) == list(range(len(transformed))) + transformed = [value for i, value in sorted(transformed.items())] + assert len(transformed) == len(expected) + + for expected_i, transformed_i in zip(expected, transformed): + tvm.testing.assert_allclose(expected_i, transformed_i) + + if __name__ == "__main__": tvm.testing.main() 
From c21577353dd89daf044187272e5dfe92492120df Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 25 Oct 2023 08:33:53 -0500 Subject: [PATCH 2/5] Add unit test to validate that the alias is preserved --- tests/python/relax/test_vm_codegen_only.py | 84 ++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index d9fb130f3c02..43ce1f75bc66 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -421,5 +421,89 @@ def main() -> R.Tensor((4,), dtype="float32"): tvm.testing.assert_allclose(res.numpy(), np.ones((4,), "float32")) +@pytest.fixture +def packed_func_check_if_exists(): + func_name = "testing.check_if_none" + + cached = tvm.get_global_func(func_name, allow_missing=True) + + @tvm.register_func(func_name, override=True) + def func(obj: tvm.Object) -> tvm.tir.IntImm: + return tvm.runtime.convert(obj is not None) + + yield func_name + + if cached is None: + tvm._ffi.registry.remove_global_func(func_name) + else: + tvm.register_func(func_name, cached, override=True) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_preserve_trivial_bindings(exec_mode, packed_func_check_if_exists): + @I.ir_module + class mod: + @R.function + def main(): + callback = R.ExternFunc(packed_func_check_if_exists) + + storage = R.vm.alloc_storage(R.shape([16]), R.prim_value(0), R.dtype("uint8")) + alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([4]), R.dtype("float32")) + storage_alias = storage + alloc_alias = alloc + + storage_before = callback(storage) + alloc_before = callback(alloc) + storage_alias_before = callback(storage_alias) + alloc_alias_before = callback(alloc_alias) + + _ = R.vm.kill_object(storage) + _ = R.vm.kill_object(alloc) + + storage_after = callback(storage) + alloc_after = callback(alloc) + storage_alias_after = callback(storage_alias) + alloc_alias_after = callback(alloc_alias) + + 
return ( + storage_before, + alloc_before, + storage_alias_before, + alloc_alias_before, + storage_after, + alloc_after, + storage_alias_after, + alloc_alias_after, + ) + + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + dev = tvm.cpu() + vm = relax.VirtualMachine(ex, dev) + + result_list = vm["main"]() + + # Making a dictionary of expected results is purely to improve + # readability of test failures. This is equivalent to asserting + # on each element of the result array, but lets pytest give us a + # diff of the dictionaries in case of failure. + expected_results = { + "storage_before": True, + "alloc_before": True, + "storage_alias_before": True, + "alloc_alias_before": True, + "storage_after": False, + "alloc_after": False, + "storage_alias_after": True, + "alloc_alias_after": True, + } + + observed_results = { + name: bool(tir_bool) for name, tir_bool in zip(expected_results.keys(), result_list) + } + + assert observed_results == expected_results + + if __name__ == "__main__": tvm.testing.main() From 9d1cb70c492de23a60936aed60740a5f6368164e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 25 Oct 2023 08:49:11 -0500 Subject: [PATCH 3/5] Use the new GetBoundValue utility function --- src/relax/backend/vm/codegen_vm.cc | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 0a6657f95793..f285c93bd7c0 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -122,15 +122,7 @@ class CodeGenVM : public ExprFunctor { Instruction::Arg VisitExpr_(const SeqExprNode* op) final { for (auto block : op->blocks) { for (Binding binding : block->bindings) { - Expr expr = [&binding]() { - if (auto* var_binding = binding.as()) { - return var_binding->value; - } else if (auto* match_cast = binding.as()) { - return match_cast->value; - } else { - LOG(FATAL) << "Unsupported binding " << 
binding->GetTypeKey(); - } - }(); + Expr expr = GetBoundValue(binding); Instruction::Arg value = VisitExpr(expr); if (expr.as()) { From e1d889cd24ffebab8ab7f33b7d00deef97f6dc4c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 25 Oct 2023 08:49:23 -0500 Subject: [PATCH 4/5] Update VMTIRCodeGen to also avoid de-duplication of bindings --- src/relax/backend/vm/codegen_vm_tir.cc | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index 9ac65f6f6eb1..ec1678e9e0f3 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -203,14 +203,20 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { Optional VisitExpr_(const SeqExprNode* op) final { for (auto block : op->blocks) { for (Binding binding : block->bindings) { - Optional value; - if (auto* var_binding = binding.as()) { - value = this->VisitExpr(var_binding->value); - } else if (auto* match_cast = binding.as()) { - value = this->VisitExpr(match_cast->value); - } else { - LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey(); + Expr expr = GetBoundValue(binding); + Optional value = VisitExpr(expr); + + if (expr.as() && value.defined()) { + // For a normalized relax module, there should be one + // register for each relax::Binding. This makes the Relax + // semantics of R.vm.kill_* operate the same as the Python + // "del" operator. These bindings may be removable by using + // relax.transform.CanonicalizeBindings earlier in lowering. 
+ auto new_reg = NewRegister(); + EmitCallPacked("vm.builtin.copy", {value.value()}, new_reg); + value = RegListGet(new_reg); } + this->var_map_.insert({binding->var, value}); } } From 625230c5cf12e341b220a2378e3271c93857dae3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 30 Oct 2023 14:24:56 -0500 Subject: [PATCH 5/5] Move the callback definition to tvm.relax.testing.vm namespace --- python/tvm/relax/testing/vm.py | 5 +++++ tests/python/relax/test_vm_codegen_only.py | 22 ++-------------------- 2 files changed, 7 insertions(+), 20 deletions(-) diff --git a/python/tvm/relax/testing/vm.py b/python/tvm/relax/testing/vm.py index 79da54be1010..37bcf870a5df 100644 --- a/python/tvm/relax/testing/vm.py +++ b/python/tvm/relax/testing/vm.py @@ -83,3 +83,8 @@ def check_saved_func(vm: relax.VirtualMachine, func_name: str, *inputs: List[Any res2 = vm[saved_name]() tvm.testing.assert_allclose(res1.numpy(), res2.numpy(), rtol=1e-7, atol=1e-7) return res1 + + +@tvm.register_func("test.vm.check_if_defined") +def check_if_defined(obj: tvm.Object) -> tvm.tir.IntImm: + return tvm.runtime.convert(obj is not None) diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index 43ce1f75bc66..0d461f0713c2 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -421,31 +421,13 @@ def main() -> R.Tensor((4,), dtype="float32"): tvm.testing.assert_allclose(res.numpy(), np.ones((4,), "float32")) -@pytest.fixture -def packed_func_check_if_exists(): - func_name = "testing.check_if_none" - - cached = tvm.get_global_func(func_name, allow_missing=True) - - @tvm.register_func(func_name, override=True) - def func(obj: tvm.Object) -> tvm.tir.IntImm: - return tvm.runtime.convert(obj is not None) - - yield func_name - - if cached is None: - tvm._ffi.registry.remove_global_func(func_name) - else: - tvm.register_func(func_name, cached, override=True) - - @pytest.mark.parametrize("exec_mode", 
EXEC_MODE) -def test_preserve_trivial_bindings(exec_mode, packed_func_check_if_exists): +def test_preserve_trivial_bindings(exec_mode): @I.ir_module class mod: @R.function def main(): - callback = R.ExternFunc(packed_func_check_if_exists) + callback = R.ExternFunc("test.vm.check_if_defined") storage = R.vm.alloc_storage(R.shape([16]), R.prim_value(0), R.dtype("uint8")) alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([4]), R.dtype("float32"))