From 93dc34679c61e9f884bdf765f98b21db481477cc Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 1 Oct 2023 21:59:43 +0000 Subject: [PATCH] Revert "[Unity] Implement relax.transform.KillAfterLastUse (#15810)" This reverts commit aa4587feb5103927d95e5e931149debd0a0aeafc. Unfortunately, this PR broke MLC LLM's build pipeline, more specifically, the command below ```python python3 -m mlc_llm.build --model dist/models/llama-2-13b-chat-hf/ --quantization q4f16_1 ``` leads to the the following error message:
``` Traceback (most recent call last): File "", line 198, in _run_module_as_main File "", line 88, in _run_code File "/opt/scratch/junrushao/mlc-llm/mlc_llm/build.py", line 13, in main() File "/opt/scratch/junrushao/mlc-llm/mlc_llm/build.py", line 10, in main core.build_model_from_args(parsed_args) File "/opt/scratch/junrushao/mlc-llm/mlc_llm/core.py", line 616, in build_model_from_args new_params = utils.convert_weights(param_manager, params, args) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/scratch/junrushao/mlc-llm/mlc_llm/utils.py", line 258, in convert_weights vm["transform_params"]() File "/opt/scratch/junrushao/tvm-dev/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__ raise_last_ffi_error() File "/opt/scratch/junrushao/tvm-dev/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error raise py_err File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 634, in tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) clo->impl.CallPacked(TVMArgs(values.data(), tcodes.data(), args.size() + 1), rv); ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 708, in operator() *rv = static_cast(ctx_ptr)->InvokeBytecode(gf_idx, inputs); ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 765, in tvm::runtime::relax_vm::VirtualMachineImpl::InvokeBytecode(long, std::vector > const&) RunLoop(); File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 890, in tvm::runtime::relax_vm::VirtualMachineImpl::RunLoop() this->RunInstrCall(curr_frame, instr); ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 843, in tvm::runtime::relax_vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::relax_vm::VMFrame*, tvm::runtime::relax_vm::Instruction) this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret); ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/scratch/junrushao/mlc-llm/mlc_llm/relax_model/param_manager.py", line 579, in set_item loaded_params[i] = tvm.nd.array(computed_param, device=device_cpu) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/scratch/junrushao/tvm-dev/python/tvm/runtime/ndarray.py", line 635, in array return empty(arr.shape, arr.dtype, device, mem_scope).copyfrom(arr) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/scratch/junrushao/tvm-dev/python/tvm/runtime/ndarray.py", line 390, in empty dtype = DataType(dtype) ^^^^^^^^^^^^^^^ File "/opt/scratch/junrushao/tvm-dev/python/tvm/_ffi/runtime_ctypes.py", line 174, in __init__ raise ValueError("Do not know how to handle type %s" % type_str) ValueError: Do not know how to handle type object ```
To briefly explain the root cause of this issue, `set_item` method, as defined [here](https://github.com/mlc-ai/mlc-llm/blob/4f4a93f03fed3d900605c02de575d7d5f429ed79/mlc_llm/relax_model/param_manager.py#L576-L579), gets `computed_param=None` after this commit. As a temporary solution, I'd love to propose that we revert this commit for now to quickly unblock us from building any LLM model, but I'm happy to get it back immediately as soon as the issue is fixed. --- python/tvm/relax/transform/transform.py | 10 - python/tvm/relax/vm_build.py | 1 - src/relax/transform/kill_after_last_use.cc | 289 ------------------ src/relax/transform/utils.h | 13 - .../python/relax/test_kill_after_last_use.py | 55 ---- 5 files changed, 368 deletions(-) delete mode 100644 src/relax/transform/kill_after_last_use.cc delete mode 100644 tests/python/relax/test_kill_after_last_use.py diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 2b7a788e3233..0184bb122842 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -388,16 +388,6 @@ def LowerAllocTensor() -> tvm.ir.transform.Pass: return _ffi_api.LowerAllocTensor() # type: ignore -def KillAfterLastUse() -> tvm.ir.transform.Pass: - """Drop all tensor/storage objects after last use - - Returns - ------- - ret : tvm.ir.transform.Pass - """ - return _ffi_api.KillAfterLastUse() # type: ignore - - def VMBuiltinLower() -> tvm.ir.transform.Pass: """Lowering generic intrinsic to VM intrinsics. diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index d5edeeec69b2..85c45c490140 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -315,7 +315,6 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): passes.append(relax.transform.RewriteCUDAGraph()) passes.append(relax.transform.LowerAllocTensor()) - passes.append(relax.transform.KillAfterLastUse()) passes.append(relax.transform.VMBuiltinLower()) passes.append(relax.transform.VMShapeLower()) diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc deleted file mode 100644 index 0f28c6c2b99a..000000000000 --- a/src/relax/transform/kill_after_last_use.cc +++ /dev/null @@ -1,289 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file src/relax/transform/kill_after_last_use.cc - * \brief Kill storage/tensor objects after last use, if not already killed - */ -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "utils.h" - -namespace tvm { -namespace relax { - -class UnusedTrivialBindingRemover : public ExprMutator { - public: - static Expr Apply(Expr expr) { - struct UsedCollector : ExprVisitor { - void VisitExpr_(const VarNode* val) override { used.insert(val); } - void VisitExpr_(const DataflowVarNode* val) override { - VisitExpr_(static_cast(val)); - } - - void VisitBinding_(const VarBindingNode* binding, const VarNode* val) override { - has_trivial_binding.insert(binding->var.get()); - ExprVisitor::VisitBinding_(binding, val); - } - void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) override { - VisitBinding_(binding, static_cast(val)); - } - - std::unordered_set used; - std::unordered_set has_trivial_binding; - }; - - UsedCollector collector; - collector(expr); - - auto to_remove = std::move(collector.has_trivial_binding); - for (const auto& used : collector.used) { - to_remove.erase(used); - } - - UnusedTrivialBindingRemover remover(to_remove); - return remover(expr); - } - - private: - explicit UnusedTrivialBindingRemover(std::unordered_set to_remove) - : to_remove_(std::move(to_remove)) {} - - void VisitBinding(const Binding& binding) override { - if (!to_remove_.count(binding->var.get())) { - ExprMutator::VisitBinding(binding); - } - } - - std::unordered_set to_remove_; -}; - -class CollectLastUsage : public ExprVisitor { - public: - struct LastUsage { - std::vector tensors; - std::vector storage; - std::vector objects; - }; - using Result = std::unordered_map; - - static Result Collect(const Expr& expr) { - CollectLastUsage visitor; - visitor(expr); - - Result output; - for (const auto* var : visitor.binding_order_) { - if (auto it = visitor.last_usage_of_.find(var); it != visitor.last_usage_of_.end()) { - const auto* last_usage_point = it->second; - bool is_output = last_usage_point == nullptr; - bool already_killed = visitor.killed_objects_.count(var); - - // Currently, the VM requires that objects to be killed - // objects only exist in VM registers. This requires - // KillAfterLastUse to have more knowledge about the VM - // implementation than should exist at this stage of lowering. - // In the future, this may be handled more easily at the - // CodeGenVM level. - bool stored_in_vm_register = - !(visitor.constant_tensors_.count(var) || var->struct_info_.as() || - var->struct_info_.as() || - var->struct_info_.as()); - - if (!is_output && !already_killed) { - if (visitor.storage_objects_.count(var)) { - output[last_usage_point].storage.push_back(var); - } else if (var->struct_info_.as() && stored_in_vm_register) { - output[last_usage_point].tensors.push_back(var); - } else if (stored_in_vm_register) { - output[last_usage_point].objects.push_back(var); - } - } - } - } - - return output; - } - - void VisitBinding(const Binding& binding) override { - auto cache = current_binding_; - current_binding_ = binding->var.get(); - binding_order_.push_back(current_binding_); - ExprVisitor::VisitBinding(binding); - current_binding_ = cache; - } - - void VisitExpr_(const VarNode* op) override { - ExprVisitor::VisitExpr_(op); - // Overwrite any previous usage, such that after the visitor - // completes, last_usage_of_ contains the last usage point. If - // this occurs in an output, then current_binding_ will be - // nullptr. - last_usage_of_[UnwrapTrivialBindings(op)] = current_binding_; - } - - void VisitBinding_(const VarBindingNode* binding, const CallNode* val) override { - static const Op& vm_alloc_storage = Op::Get("relax.vm.alloc_storage"); - static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage"); - - static const Op& mem_kill_tensor = Op::Get("relax.memory.kill_tensor"); - static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage"); - static const Op& vm_kill_object = Op::Get("relax.vm.kill_object"); - - if (val->op.same_as(vm_alloc_storage) || val->op.same_as(mem_alloc_storage)) { - storage_objects_.insert(binding->var.get()); - } else if (val->op.same_as(mem_kill_tensor) || val->op.same_as(mem_kill_storage) || - val->op.same_as(vm_kill_object)) { - CHECK_EQ(val->args.size(), 1) - << "Operator " << val->op << " should have one argument, " - << "but instead found " << val->args.size() << " arguments: " << val->args; - auto killed_object = val->args[0].as(); - ICHECK(killed_object) << "Internal error: non-normalized expression " << GetRef(val); - killed_objects_.insert(UnwrapTrivialBindings(killed_object)); - } else { - // Only recursively visit if it isn't one of the special cases. - ExprVisitor::VisitBinding_(binding, val); - } - } - - void VisitBinding_(const VarBindingNode* binding, const VarNode* val) override { - // Because the VM re-uses the same register for variable - // re-binding, we need to de-duplicate across trivial bindings in - // order to avoid calling `vm.kill_object` multiple times on the - // same register. In the future, this can be simplified by - // replacing the de-duplication in CodeGenVM with a call to - // CanonicalizeBindings. - trivial_bindings_.insert({binding->var.get(), UnwrapTrivialBindings(val)}); - - // Do not call ExprVisitor::VisitBinding_ here, as the trivial - // rebinding should not be treated as a point of use. - } - - void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) override { - constant_tensors_.insert(binding->var.get()); - } - - private: - const VarNode* UnwrapTrivialBindings(const VarNode* var) const { - while (true) { - if (auto it = trivial_bindings_.find(var); it != trivial_bindings_.end()) { - var = it->second; - } else { - return var; - } - } - } - - // The current binding being visited, or nullptr if no binding is - // being visited. - const VarNode* current_binding_{nullptr}; - - // Order of bindings, to ensure consistent order of destruction, in - // case a Binding is the last usage for more than one variable. - std::vector binding_order_; - - // Map from a variable to the last variable binding that makes use - // of it. - std::unordered_map last_usage_of_; - - // Storage objects, eligible for R.vm.kill_object. This cannot be - // determined solely from the StructInfo, because the - // `R.*.alloc_storage` operators return ObjectStructInfo - std::unordered_set storage_objects_; - - // Constants, which do not have a VM register, and may *not* have - // R.builtin.kill_tensor called on them. - std::unordered_set constant_tensors_; - - // Set of objects that already have a call node to kill them. Should not have a duplicate - std::unordered_set killed_objects_; - - // Trivial var-to-var bindings. - std::unordered_map trivial_bindings_; -}; - -class KillInserter : public ExprMutator { - private: - Expr VisitExpr_(const FunctionNode* op) override { - last_usage_ = CollectLastUsage::Collect(GetRef(op)); - auto mutated = ExprMutator::VisitExpr_(op); - last_usage_.clear(); - return mutated; - } - - Expr VisitExpr_(const SeqExprNode* op) override { - last_usage_ = CollectLastUsage::Collect(GetRef(op)); - auto mutated = ExprMutator::VisitExpr_(op); - last_usage_.clear(); - return mutated; - } - - void VisitBinding(const Binding& binding) override { - ExprMutator::VisitBinding(binding); - if (auto it = last_usage_.find(binding->var.get()); it != last_usage_.end()) { - static const Op& mem_kill_tensor = Op::Get("relax.memory.kill_tensor"); - for (const auto& tensor_obj : it->second.tensors) { - builder_->Emit(Call(mem_kill_tensor, {GetRef(tensor_obj)}), /*name_hint=*/"_"); - } - - static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage"); - for (const VarNode* storage_obj : it->second.storage) { - builder_->Emit(Call(mem_kill_storage, {GetRef(storage_obj)}), /*name_hint=*/"_"); - } - - static const Op& vm_kill_object = Op::Get("relax.vm.kill_object"); - for (const VarNode* obj : it->second.objects) { - builder_->Emit(Call(vm_kill_object, {GetRef(obj)}), /*name_hint=*/"_"); - } - } - } - - CollectLastUsage::Result last_usage_; -}; - -Expr KillAfterLastUse(Expr expr) { - expr = CanonicalizeBindings(expr); - expr = UnusedTrivialBindingRemover::Apply(expr); - - KillInserter mutator; - return mutator(expr); -} - -namespace transform { - -Pass KillAfterLastUse() { - runtime::TypedPackedFunc pass_func = - [=](Function func, IRModule m, PassContext pc) { - return Downcast(relax::KillAfterLastUse(std::move(func))); - }; - return CreateFunctionPass(pass_func, /*opt_level=*/0, "KillAfterLastUse", {}); -} - -TVM_REGISTER_GLOBAL("relax.transform.KillAfterLastUse").set_body_typed(KillAfterLastUse); - -} // namespace transform -} // namespace relax -} // namespace tvm diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 78e5c31c7589..6e44f07aa63f 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -388,19 +388,6 @@ inline String GetCodegenName(const std::string& composite_name) { */ Expr EliminateCommonSubexpr(const Expr& expr, bool call_only = false); -/* \brief Remove use of trivial bindings - * - * Utility for simplifying relax expressions by folding var bindings - * and match shape nodes. May include other forms of simplification - * in the future. Ideally should be used before constant folding and - * eliminating unused bindings. - * - * \param expr The expression to be canonicalized - * - * \ret The canonicalized expression - */ -Expr CanonicalizeBindings(const Expr& expr); - } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_kill_after_last_use.py b/tests/python/relax/test_kill_after_last_use.py deleted file mode 100644 index eb6e0777ae2a..000000000000 --- a/tests/python/relax/test_kill_after_last_use.py +++ /dev/null @@ -1,55 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm -import tvm.relax -import tvm.testing - -from tvm.script import ir as I, relax as R - -from tvm.relax.transform import KillAfterLastUse - - -def test_basic(): - @I.ir_module - class Before: - @R.function(pure=False) - def main(x: R.Tensor([16, 32], "float32")): - storage = R.memory.alloc_storage(R.shape([2048]), 0, "global", "uint8") - y = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32") - _dummy = R.call_packed("add_tensors", [x, y], sinfo_args=(R.Tuple,)) - z = R.add(x, y) - return z - - @I.ir_module - class Expected: - @R.function(pure=False) - def main(x: R.Tensor([16, 32], "float32")): - storage = R.memory.alloc_storage(R.shape([2048]), 0, "global", "uint8") - y = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32") - _ = R.memory.kill_storage(storage) - _dummy = R.call_packed("add_tensors", [x, y], sinfo_args=(R.Tuple,)) - z = R.add(x, y) - _ = R.memory.kill_tensor(y) - return z - - After = KillAfterLastUse()(Before) - tvm.ir.assert_structural_equal(Expected, After) - - -if __name__ == "__main__": - tvm.testing.main()