
Conversation

@Lunderberg
Contributor

Prior to this commit, intermediate objects produced while executing a Relax function would persist until the end of the Relax function. While re-use of static allocations is handled by the `StaticPlanBlockMemory` transform, re-use of dynamic allocations is handled by the `relax_vm::PooledAllocator`. For large Relax functions representing end-to-end model execution, releasing memory from the VM registers to the `relax_vm::PooledAllocator` at the end of the function call may be insufficient.

This commit introduces a new pass, `relax.transform.KillAfterLastUse`, which identifies the last usage of each Relax variable and inserts a `relax.memory.kill_tensor`, `relax.memory.kill_storage`, or `relax.vm.kill_object` call depending on the object type. This insertion is suppressed if a Relax variable is already killed, as is the case for static allocations and tensors tracked by `StaticPlanBlockMemory`.
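
A minimal usage sketch of the new pass, assuming an `IRModule` (here called `mod`) that has already been lowered far enough to contain the explicit allocations the pass operates on; this is illustrative only, not a prescribed pipeline position:

```python
import tvm
from tvm import relax

def insert_kill_ops(mod: tvm.IRModule) -> tvm.IRModule:
    """Illustrative helper (not part of the PR): apply KillAfterLastUse.

    The pass finds the last use of each Relax variable and appends the
    matching kill call (kill_tensor / kill_storage / kill_object),
    skipping variables that are already killed.
    """
    return relax.transform.KillAfterLastUse()(mod)
```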

@Lunderberg force-pushed the unity_kill_after_last_use branch from 894868a to 9559725 on September 25, 2023 at 14:12
@Lunderberg force-pushed the unity_kill_after_last_use branch from 9559725 to d5205dc on September 26, 2023 at 15:22
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Sep 26, 2023
This was implemented while debugging CI failures in
apache#15810, but is not otherwise related
to the changes in that PR.
```cpp
namespace tvm {
namespace relax {

class UnusedTrivialBindingRemover : public ExprMutator {
```
Contributor

Seems like a reasonable simplification utility that could be used elsewhere?

Contributor Author

Agreed, and that's my plan for a follow-up PR. Well, either that, or to merge it with the existing CanonicalizeBindings utility, which almost always requires a dead-code elimination step afterwards to remove the unused trivial bindings.
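
For reference, a trivial binding is a direct re-binding of one variable to another. A hypothetical TVMScript sketch (not taken from the PR) of the pattern that `CanonicalizeBindings` plus the remover above would clean up:

```python
from tvm.script import ir as I, relax as R

@I.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        y = x            # trivial re-binding of x
        z = R.add(y, y)  # CanonicalizeBindings rewrites this to use x,
                         # leaving `y = x` as an unused trivial binding
        return z
```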

@Lunderberg force-pushed the unity_kill_after_last_use branch from 8b79c6a to ebc90a0 on September 27, 2023 at 19:42
@Lunderberg merged commit aa4587f into apache:unity on Sep 28, 2023
@Lunderberg deleted the unity_kill_after_last_use branch on September 28, 2023 at 13:54
@junrushao
Member

Unfortunately, this PR broke MLC LLM's build pipeline. More specifically, the command below

```bash
python3 -m mlc_llm.build --model dist/models/llama-2-13b-chat-hf/ --quantization q4f16_1
```

leads to the following error message:

```
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/scratch/junrushao/mlc-llm/mlc_llm/build.py", line 13, in <module>
    main()
  File "/opt/scratch/junrushao/mlc-llm/mlc_llm/build.py", line 10, in main
    core.build_model_from_args(parsed_args)
  File "/opt/scratch/junrushao/mlc-llm/mlc_llm/core.py", line 616, in build_model_from_args
    new_params = utils.convert_weights(param_manager, params, args)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/mlc-llm/mlc_llm/utils.py", line 258, in convert_weights
    vm["transform_params"]()
  File "/opt/scratch/junrushao/tvm-dev/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/opt/scratch/junrushao/tvm-dev/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
  File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 634, in tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
    clo->impl.CallPacked(TVMArgs(values.data(), tcodes.data(), args.size() + 1), rv);
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 708, in operator()
    *rv = static_cast<VirtualMachineImpl*>(ctx_ptr)->InvokeBytecode(gf_idx, inputs);
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 765, in tvm::runtime::relax_vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::runtime::TVMRetValue, std::allocator<tvm::runtime::TVMRetValue> > const&)
    RunLoop();

  File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 890, in tvm::runtime::relax_vm::VirtualMachineImpl::RunLoop()
    this->RunInstrCall(curr_frame, instr);
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/src/runtime/relax_vm/vm.cc", line 843, in tvm::runtime::relax_vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::relax_vm::VMFrame*, tvm::runtime::relax_vm::Instruction)
    this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/mlc-llm/mlc_llm/relax_model/param_manager.py", line 579, in set_item
    loaded_params[i] = tvm.nd.array(computed_param, device=device_cpu)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/python/tvm/runtime/ndarray.py", line 635, in array
    return empty(arr.shape, arr.dtype, device, mem_scope).copyfrom(arr)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/python/tvm/runtime/ndarray.py", line 390, in empty
    dtype = DataType(dtype)
            ^^^^^^^^^^^^^^^
  File "/opt/scratch/junrushao/tvm-dev/python/tvm/_ffi/runtime_ctypes.py", line 174, in __init__
    raise ValueError("Do not know how to handle type %s" % type_str)
ValueError: Do not know how to handle type object
```

To briefly explain the root cause of this issue: the `set_item` method, defined in `mlc_llm/relax_model/param_manager.py`, gets `computed_param=None` after this commit.
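
For illustration, the bottom of the traceback is exactly what `tvm.nd.array` produces when handed `None`; a minimal sketch of the failure mode (not the actual MLC LLM code path):

```python
import tvm

computed_param = None         # what set_item ends up receiving after this commit
tvm.nd.array(computed_param)  # NumPy coerces None to a 0-d array of dtype "object",
                              # so DataType("object") raises:
                              # ValueError: Do not know how to handle type object
```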

As a temporary solution, I'd like to propose that we revert this commit for now to quickly unblock LLM model builds, but I'm happy to bring it back as soon as the issue is fixed.

junrushao added a commit to junrushao/tvm that referenced this pull request Oct 1, 2023
)"

This reverts commit aa4587f.

Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Oct 2, 2023
The `KillAfterLastUse` pass that was implemented in
apache#15810 checked for trivial
re-bindings in `VarBinding` nodes, but not in `MatchCast` nodes.  As a
result, `CodeGenVM`'s de-duplication of registers resulted in the
object being killed prematurely.

```python
y = R.match_cast(x, R.Tensor(...)) # Trivial rebinding here.
                                   # CodeGenVM has these share a register.

R.memory.kill_tensor(x)  # Kill x after last usage.
                         # Register is set to None.

_ = R.ExternFunc("set_item")(y) # Use of the cleared register through y.
```
junrushao pushed a commit that referenced this pull request Oct 3, 2023

masahi pushed a commit that referenced this pull request Oct 13, 2023
* [Unity][VM] Improved error message in CodeGenVM::EmitKillObject

This was implemented while debugging CI failures in
#15810, but is not otherwise related
to the changes in that PR.

* ci bump