diff --git a/python/tvm/relax/testing/vm.py b/python/tvm/relax/testing/vm.py
index 79da54be1010..37bcf870a5df 100644
--- a/python/tvm/relax/testing/vm.py
+++ b/python/tvm/relax/testing/vm.py
@@ -83,3 +83,8 @@ def check_saved_func(vm: relax.VirtualMachine, func_name: str, *inputs: List[Any
     res2 = vm[saved_name]()
     tvm.testing.assert_allclose(res1.numpy(), res2.numpy(), rtol=1e-7, atol=1e-7)
     return res1
+
+
+@tvm.register_func("test.vm.check_if_defined")
+def check_if_defined(obj: tvm.Object) -> tvm.tir.IntImm:
+    return tvm.runtime.convert(obj is not None)
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
index caee0a0c13d6..f285c93bd7c0 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -122,14 +122,20 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
   Instruction::Arg VisitExpr_(const SeqExprNode* op) final {
     for (auto block : op->blocks) {
       for (Binding binding : block->bindings) {
-        Instruction::Arg value;
-        if (auto* var_binding = binding.as<VarBindingNode>()) {
-          value = this->VisitExpr(var_binding->value);
-        } else if (auto* match_cast = binding.as<MatchCastNode>()) {
-          value = this->VisitExpr(match_cast->value);
-        } else {
-          LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
-        }
+        Expr expr = GetBoundValue(binding);
+
+        Instruction::Arg value = VisitExpr(expr);
+        if (expr.as<VarNode>()) {
+          // For a normalized relax module, there should be one
+          // register for each relax::Binding. This makes the Relax
+          // semantics of R.vm.kill_* operate the same as the Python
+          // "del" operator. These bindings may be removable by using
+          // relax.transform.CanonicalizeBindings earlier in lowering.
+          RegName new_reg = NewRegister();
+          builder_->EmitCall("vm.builtin.copy", {value}, new_reg);
+          value = Instruction::Arg::Register(new_reg);
+        }
+
         this->var_arg_map_.insert({binding->var, value});
       }
     }
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc
index 9ac65f6f6eb1..ec1678e9e0f3 100644
--- a/src/relax/backend/vm/codegen_vm_tir.cc
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -203,14 +203,20 @@ class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
   Optional<PrimExpr> VisitExpr_(const SeqExprNode* op) final {
     for (auto block : op->blocks) {
       for (Binding binding : block->bindings) {
-        Optional<PrimExpr> value;
-        if (auto* var_binding = binding.as<VarBindingNode>()) {
-          value = this->VisitExpr(var_binding->value);
-        } else if (auto* match_cast = binding.as<MatchCastNode>()) {
-          value = this->VisitExpr(match_cast->value);
-        } else {
-          LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
-        }
+        Expr expr = GetBoundValue(binding);
+        Optional<PrimExpr> value = VisitExpr(expr);
+
+        if (expr.as<VarNode>() && value.defined()) {
+          // For a normalized relax module, there should be one
+          // register for each relax::Binding. This makes the Relax
+          // semantics of R.vm.kill_* operate the same as the Python
+          // "del" operator. These bindings may be removable by using
+          // relax.transform.CanonicalizeBindings earlier in lowering.
+          auto new_reg = NewRegister();
+          EmitCallPacked("vm.builtin.copy", {value.value()}, new_reg);
+          value = RegListGet(new_reg);
+        }
+
         this->var_map_.insert({binding->var, value});
       }
     }
diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc
index 5b6a098ab6a6..fdb32356e67f 100644
--- a/src/relax/transform/kill_after_last_use.cc
+++ b/src/relax/transform/kill_after_last_use.cc
@@ -149,7 +149,7 @@ class CollectLastUsage : public ExprVisitor {
     // completes, last_usage_of_ contains the last usage point.  If
     // this occurs in an output, then current_binding_ will be
     // nullptr.
-    last_usage_of_[UnwrapTrivialBindings(op)] = current_binding_;
+    last_usage_of_[op] = current_binding_;
   }
 
   void VisitBinding_(const VarBindingNode* binding, const CallNode* val) override {
@@ -169,51 +169,18 @@ class CollectLastUsage : public ExprVisitor {
           << "but instead found " << val->args.size() << " arguments: " << val->args;
       auto killed_object = val->args[0].as<VarNode>();
       ICHECK(killed_object) << "Internal error: non-normalized expression " << GetRef<Expr>(val);
-      killed_objects_.insert(UnwrapTrivialBindings(killed_object));
+      killed_objects_.insert(killed_object);
     } else {
       // Only recursively visit if it isn't one of the special cases.
       ExprVisitor::VisitBinding_(binding, val);
     }
   }
 
-  void VisitBinding_(const VarBindingNode* binding, const VarNode* val) override {
-    // Because the VM re-uses the same register for variable
-    // re-binding, we need to de-duplicate across trivial bindings in
-    // order to avoid calling `vm.kill_object` multiple times on the
-    // same register. In the future, this can be simplified by
-    // replacing the de-duplication in CodeGenVM with a call to
-    // CanonicalizeBindings.
-    trivial_bindings_.insert({binding->var.get(), UnwrapTrivialBindings(val)});
-
-    // Do not call ExprVisitor::VisitBinding_ here, as the trivial
-    // rebinding should not be treated as a point of use.
-  }
-
-  void VisitBinding_(const MatchCastNode* binding) override {
-    if (auto rebound = binding->value.as<VarNode>()) {
-      // Because CodeGenVM treats MatchCast nodes identically to
-      // VarBinding nodes, we must also de-duplicate at this level.
-      trivial_bindings_.insert({binding->var.get(), UnwrapTrivialBindings(rebound)});
-    } else {
-      ExprVisitor::VisitBinding_(binding);
-    }
-  }
-
   void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) override {
     constant_tensors_.insert(binding->var.get());
   }
 
  private:
-  const VarNode* UnwrapTrivialBindings(const VarNode* var) const {
-    while (true) {
-      if (auto it = trivial_bindings_.find(var); it != trivial_bindings_.end()) {
-        var = it->second;
-      } else {
-        return var;
-      }
-    }
-  }
-
   // The current binding being visited, or nullptr if no binding is
   // being visited.
   const VarNode* current_binding_{nullptr};
diff --git a/tests/python/relax/test_kill_after_last_use.py b/tests/python/relax/test_kill_after_last_use.py
index 1e4dc71877eb..41f4409f0b99 100644
--- a/tests/python/relax/test_kill_after_last_use.py
+++ b/tests/python/relax/test_kill_after_last_use.py
@@ -94,8 +94,9 @@ class Expected:
         def main(w: R.Tensor([16, 32], "float32")):
             x = R.add(w, R.const(1, "float32"))
             y = R.match_cast(x, R.Tensor([16, 32]))
-            z = R.add(y, R.const(1, "float32"))
             _ = R.memory.kill_tensor(x)
+            z = R.add(y, R.const(1, "float32"))
+            _ = R.memory.kill_tensor(y)
             return z
 
     After = KillAfterLastUse()(Before)
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py
index 9ecebafae6df..6ce728ba95af 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -14,9 +14,12 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np
+
 import tvm
 import tvm.testing
-import pytest
+
+from tvm import relax
 from tvm.script import relax as R, tir as T
 from tvm.script import ir as I
 from tvm.relax.transform import LazyTransformParams
@@ -319,5 +322,59 @@ def main_transform_params() -> R.Tuple:
     tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
 
 
+def test_output():
+    target = "llvm"
+    dev = tvm.device(target)
+
+    @I.ir_module
+    class TransformModule:
+        @R.function
+        def transform_params(
+            params: R.Tuple(
+                R.Tensor((3, "ic", 3, 3), dtype="float32"),
+                R.Tensor((16, 16, 3, 3), dtype="float32"),
+            )
+        ) -> R.Tuple(
+            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32")
+        ):
+            R.func_attr({"relax.force_pure": True})
+            param0 = params[0]
+            param1 = params[1]
+            transformed0 = R.permute_dims(param0, [1, 0, 2, 3])
+            transformed = (transformed0, param1)
+            return transformed
+
+    mod = TransformModule
+    mod = relax.transform.LazyTransformParams()(mod)
+    mod = relax.transform.LegalizeOps()(mod)
+    built = relax.build(mod, target=target)
+
+    params = [
+        np.random.random(size=(3, 64, 3, 3)).astype("float32"),
+        np.random.random(size=(16, 16, 3, 3)).astype("float32"),
+    ]
+    transformed = {}
+    expected = [params[0].transpose(1, 0, 2, 3), params[1]]
+
+    @tvm.register_func("get_item", override=True)
+    def get_item(i):
+        return tvm.nd.array(params[i], dev)
+
+    @tvm.register_func("set_item", override=True)
+    def set_item(i, value):
+        assert i not in transformed, f"Set item called multiple times for index {i}"
+        transformed[i] = value.numpy()
+
+    vm = relax.VirtualMachine(built, dev)
+    vm["transform_params"]()
+
+    assert sorted(transformed) == list(range(len(transformed)))
+    transformed = [value for i, value in sorted(transformed.items())]
+    assert len(transformed) == len(expected)
+
+    for expected_i, transformed_i in zip(expected, transformed):
+        tvm.testing.assert_allclose(expected_i, transformed_i)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py
index d9fb130f3c02..0d461f0713c2 100644
--- a/tests/python/relax/test_vm_codegen_only.py
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -421,5 +421,71 @@ def main() -> R.Tensor((4,), dtype="float32"):
     tvm.testing.assert_allclose(res.numpy(), np.ones((4,), "float32"))
 
 
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_preserve_trivial_bindings(exec_mode):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main():
+            callback = R.ExternFunc("test.vm.check_if_defined")
+
+            storage = R.vm.alloc_storage(R.shape([16]), R.prim_value(0), R.dtype("uint8"))
+            alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([4]), R.dtype("float32"))
+            storage_alias = storage
+            alloc_alias = alloc
+
+            storage_before = callback(storage)
+            alloc_before = callback(alloc)
+            storage_alias_before = callback(storage_alias)
+            alloc_alias_before = callback(alloc_alias)
+
+            _ = R.vm.kill_object(storage)
+            _ = R.vm.kill_object(alloc)
+
+            storage_after = callback(storage)
+            alloc_after = callback(alloc)
+            storage_alias_after = callback(storage_alias)
+            alloc_alias_after = callback(alloc_alias)
+
+            return (
+                storage_before,
+                alloc_before,
+                storage_alias_before,
+                alloc_alias_before,
+                storage_after,
+                alloc_after,
+                storage_alias_after,
+                alloc_alias_after,
+            )
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = codegen(mod, target, exec_mode)
+    dev = tvm.cpu()
+    vm = relax.VirtualMachine(ex, dev)
+
+    result_list = vm["main"]()
+
+    # Making a dictionary of expected results is purely to improve
+    # readability of test failures. This is equivalent to asserting
+    # on each element of the result array, but lets pytest give us a
+    # diff of the dictionaries in case of failure.
+    expected_results = {
+        "storage_before": True,
+        "alloc_before": True,
+        "storage_alias_before": True,
+        "alloc_alias_before": True,
+        "storage_after": False,
+        "alloc_after": False,
+        "storage_alias_after": True,
+        "alloc_alias_after": True,
+    }
+
+    observed_results = {
+        name: bool(tir_bool) for name, tir_bool in zip(expected_results.keys(), result_list)
+    }
+
+    assert observed_results == expected_results
+
+
 if __name__ == "__main__":
     tvm.testing.main()
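
Note: the comments added to both VM code generators point at relax.transform.CanonicalizeBindings as a way to fold trivial rebindings away earlier in lowering, so the extra vm.builtin.copy never needs to be emitted. A minimal sketch of that usage follows; the module is illustrative only and is not taken from this patch.

from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor([4], "float32")) -> R.Tensor([4], "float32"):
        y = x  # trivial rebinding, analogous to storage_alias = storage above
        z = R.add(y, y)
        return z


# CanonicalizeBindings rewrites uses of `y` to refer to `x` directly,
# so the VM codegen never sees a Var-to-Var binding and emits no copy.
canonicalized = relax.transform.CanonicalizeBindings()(Module)
canonicalized.show()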