[Unity] Implement relax.transform.KillAfterLastUse #15810
This was implemented while debugging CI failures in apache#15810, but is not otherwise related to the changes in that PR.
    namespace tvm {
    namespace relax {

    class UnusedTrivialBindingRemover : public ExprMutator {
Seems like a reasonable simplification utility that could be used elsewhere?
Agreed, and that's my plan for a follow-up PR. Well, either that, or to merge it with the existing CanonicalizeBindings utility, which almost always requires a dead-code elimination step afterwards to remove the unused trivial bindings.
Prior to this commit, intermediate objects produced while executing a Relax function would persist until the end of the Relax function. While re-use of static allocations is handled by the `StaticPlanBlockMemory` transform, re-use of dynamic allocations is handled by the `relax_vm::PooledAllocator`. For large Relax functions representing end-to-end model execution, releasing memory from the VM registers to the `relax_vm::PooledAllocator` at the end of the function call may be insufficient.

This commit introduces a new pass, `relax.transform.KillAfterLastUse`, which identifies the last usage of each Relax variable and inserts a `relax.memory.kill_tensor`, `relax.memory.kill_storage`, or `relax.vm.kill_object` call depending on the object type. This insertion is suppressed if a Relax variable is already killed, such as static allocations and tensors tracked by `StaticPlanBlockMemory`.
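The idea behind the pass can be illustrated without TVM. The following is a minimal pure-Python sketch under stated assumptions, not the actual implementation: bindings are modeled as `(var, op, args)` tuples, and `insert_kills` is a hypothetical helper that appends a `kill` pseudo-op right after the last binding that reads each variable. Unlike the real pass, this toy does not choose between `relax.memory.kill_tensor`, `relax.memory.kill_storage`, and `relax.vm.kill_object`, and it does not check whether a variable has already been killed.

```python
def insert_kills(bindings, outputs):
    """Toy last-use analysis over a straight-line list of bindings.

    bindings: list of (var, op, args) tuples; args are variable names.
    outputs: names that are returned and must therefore stay alive.
    (Both the tuple encoding and this helper are hypothetical.)
    """
    # Record the index of the last binding that reads each variable.
    last_use = {}
    for idx, (_var, _op, args) in enumerate(bindings):
        for arg in args:
            last_use[arg] = idx

    result = []
    for idx, binding in enumerate(bindings):
        result.append(binding)
        # Kill every variable whose final read happened in this binding.
        for name in sorted(n for n, last in last_use.items() if last == idx):
            if name not in outputs:
                result.append(("_", "kill", (name,)))
    return result


program = [
    ("a", "alloc", ()),
    ("b", "add", ("a", "a")),
    ("c", "mul", ("b", "b")),
]
for binding in insert_kills(program, outputs={"c"}):
    print(binding)
```

In this toy program, `a` is killed immediately after `b` is computed and `b` immediately after `c` is computed, rather than both surviving until the function returns.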
This reverts commit aa4587f.

Unfortunately, this PR broke MLC LLM's build pipeline. More specifically, the command below

```python
python3 -m mlc_llm.build --model dist/models/llama-2-13b-chat-hf/ --quantization q4f16_1
```

leads to the following error message:

<details>

```
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/scratch/junrushao/mlc-llm/mlc_llm/build.py", line 13, in <module>
    main()
  File "/opt/scratch/junrushao/mlc-llm/mlc_llm/build.py", line 10, in main
    core.build_model_from_args(parsed_args)
  File "/opt/scratch/junrushao/mlc-llm/mlc_llm/core.py", line 616, in build_model_from_args
    new_params = utils.convert_weights(param_manager, params, args)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/mlc-llm/mlc_llm/utils.py", line 258, in convert_weights
    vm["transform_params"]()
  File "/opt/scratch/junrushao/tvm-dev/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/opt/scratch/junrushao/tvm-dev/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
  File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 634, in tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
    clo->impl.CallPacked(TVMArgs(values.data(), tcodes.data(), args.size() + 1), rv);
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 708, in operator()
    *rv = static_cast<VirtualMachineImpl*>(ctx_ptr)->InvokeBytecode(gf_idx, inputs);
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 765, in tvm::runtime::relax_vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::runtime::TVMRetValue, std::allocator<tvm::runtime::TVMRetValue> > const&)
    RunLoop();
  File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 890, in tvm::runtime::relax_vm::VirtualMachineImpl::RunLoop()
    this->RunInstrCall(curr_frame, instr);
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 843, in tvm::runtime::relax_vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::relax_vm::VMFrame*, tvm::runtime::relax_vm::Instruction)
    this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/mlc-llm/mlc_llm/relax_model/param_manager.py", line 579, in set_item
    loaded_params[i] = tvm.nd.array(computed_param, device=device_cpu)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/python/tvm/runtime/ndarray.py", line 635, in array
    return empty(arr.shape, arr.dtype, device, mem_scope).copyfrom(arr)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/python/tvm/runtime/ndarray.py", line 390, in empty
    dtype = DataType(dtype)
            ^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/python/tvm/_ffi/runtime_ctypes.py", line 174, in __init__
    raise ValueError("Do not know how to handle type %s" % type_str)
ValueError: Do not know how to handle type object
```

</details>

To briefly explain the root cause of this issue, the `set_item` method, as defined [here](https://github.com/mlc-ai/mlc-llm/blob/4f4a93f03fed3d900605c02de575d7d5f429ed79/mlc_llm/relax_model/param_manager.py#L576-L579), gets `computed_param=None` after this commit.

As a temporary solution, I'd love to propose that we revert this commit for now to quickly unblock us from building any LLM model, but I'm happy to get it back immediately as soon as the issue is fixed.
The `KillAfterLastUse` pass that was implemented in apache#15810 checked for trivial re-bindings in `VarBinding` nodes, but not in `MatchCast` nodes. As a result, `CodeGenVM`'s de-duplication of registers resulted in the object being killed prematurely.

```python
y = R.match_cast(x, R.Tensor(...))  # Trivial rebinding here.
                                    # CodeGenVM has these share a register.
R.memory.kill_tensor(x)             # Kill x after last usage.
                                    # Register is set to None.
_ = R.ExternFunc("set_item")(y)     # Use of the cleared register through y.
```
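The failure mode can be reproduced in a small pure-Python simulation. This is only a sketch under stated assumptions, not TVM code: the `(var, op, args)` encoding, the `run` helper, and the `alias_aware` flag are all hypothetical, and the `match_cast` op stands in for any trivial rebinding that shares a register with its source. With per-variable last-use tracking the shared register is cleared too early; tracking last use per alias group is one way to avoid that.

```python
def run(program, alias_aware):
    """Execute a toy program; return the values observed by each 'use' op."""
    # Map every variable to its alias root. A 'match_cast' is modeled as a
    # trivial rebinding that shares a register with its source variable.
    root = {}
    for var, op, args in program:
        root[var] = root[args[0]] if op == "match_cast" else var

    # Track the last read either per variable or per alias root.
    key = (lambda v: root[v]) if alias_aware else (lambda v: v)
    last_use = {}
    for idx, (_var, _op, args) in enumerate(program):
        for arg in args:
            last_use[key(arg)] = idx

    registers = {}
    observed = []
    for idx, (var, op, args) in enumerate(program):
        if op == "alloc":
            registers[root[var]] = "tensor"
        elif op == "use":
            observed.append(registers[root[args[0]]])
        # Clear the register of anything whose last use was this binding.
        for name, last in last_use.items():
            if last == idx:
                registers[root[name]] = None
    return observed


program = [
    ("x", "alloc", ()),
    ("y", "match_cast", ("x",)),  # y shares x's register
    ("_", "use", ("y",)),         # reads the shared register through y
]
print(run(program, alias_aware=False))  # [None]
print(run(program, alias_aware=True))   # ['tensor']
```

The first call mirrors the bug: `x` is killed after the `match_cast`, so the later read through `y` sees `None`. The second call defers the kill until the last use of any alias.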
* [Unity][VM] Improved error message in CodeGenVM::EmitKillObject

  This was implemented while debugging CI failures in #15810, but is not otherwise related to the changes in that PR.

* ci bump