[Unity] Implement LowerAllocTensor to remove R.builtin.alloc_tensor #15809
Conversation
passes.append(relax.transform.LowerAllocTensor())

if tvm.transform.PassContext.current().config.get("relax.backend.use_cuda_graph", False):
    passes.append(relax.transform.RewriteCUDAGraph())
Did you look into what assumptions relax.transform.RewriteCUDAGraph makes about static vs. dynamic allocations, to ensure the expectations match before and after this change? RewriteCUDAGraph has some dependence on alloc_tensor.
Thank you for the reminder. I had assumed that the tests in tests/python/relax/test_transform_rewrite_cuda_graph.py and tests/python/relax/test_vm_cuda_graph.py would be sufficient. On closer inspection, it turns out that the former provides the input to RewriteCUDAGraph, the latter checks the output from RewriteCUDAGraph, and neither tests the behavior of the pass as it exists within a lowering flow. After writing a quick end-to-end test, there is an issue that occurs within the CUDA graph rewriting pass.
It appears to be a bug in RewriteCUDAGraph, which occurs when an R.memory.alloc_storage result is then used in a trivial var-to-var rebinding.
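To make the failing pattern concrete, here is a minimal sketch on a toy binding list (the names and tuple structure are illustrative, not TVM's real IR). A rewrite that only checks the directly-bound variable misses the alias as a handle to the storage; one that follows trivial rebindings finds it:

```python
# Toy illustration of the aliasing pattern (illustrative data structures,
# not TVM code): a storage allocation followed by a trivial var-to-var
# rebinding, which a pass must resolve to find the underlying allocation.
bindings = [
    ("storage", "R.memory.alloc_storage", ()),
    ("alias", "var", ("storage",)),  # trivial rebinding: alias = storage
    ("tensor", "R.memory.alloc_tensor", ("alias",)),
]

def storage_vars(bindings):
    """Collect variables that transitively refer to a storage allocation,
    following trivial var-to-var rebindings."""
    known = set()
    for var, op, args in bindings:
        if op == "R.memory.alloc_storage":
            known.add(var)
        elif op == "var" and args[0] in known:
            known.add(var)  # the alias refers to the same storage
    return known
```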
This bug ended up being trickier to track down than I had expected. It will be much simpler to solve after #15810 lands, since the .kill_* methods won't yet have been inserted. For now, I've re-ordered the passes so that LowerAllocTensor occurs after RewriteCUDAGraph.
Awesome, many thanks for looking into this @Lunderberg; changing the ordering until the fix is in makes sense. Let's land this and #15810.
The `StaticPlanBlockMemory` transform is provided a module that expresses all allocations with `R.builtin.alloc_tensor`, and produces a module that uses `R.memory.alloc_storage` and `R.memory.alloc_tensor` to express static allocations, while dynamic allocations continue to use `R.builtin.alloc_tensor`. Prior to this commit, this mixed output was handled as part of `VMBuiltinLower`. This commit extracts the lowering of `R.builtin.alloc_tensor` to a new pass, `LowerAllocTensor`. This pass runs after `StaticPlanBlockMemory`, and replaces any remaining `R.builtin.alloc_tensor` with calls to `R.memory.alloc_storage` and `R.memory.alloc_tensor`.
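As a rough illustration of the replacement this new pass performs, the following sketch applies the same rewrite to a toy binding list. The `Binding` tuple and helper names here are invented for this example; TVM's actual pass operates on Relax IR:

```python
from typing import List, Tuple

# Toy IR: each binding is (result_var, op_name, args). This is an
# illustrative stand-in for Relax bindings, not TVM's real classes.
Binding = Tuple[str, str, tuple]

def lower_alloc_tensor(bindings: List[Binding]) -> List[Binding]:
    """Replace each remaining R.builtin.alloc_tensor with an explicit
    R.memory.alloc_storage + R.memory.alloc_tensor pair, mirroring the
    lowering that LowerAllocTensor performs."""
    out: List[Binding] = []
    counter = 0
    for var, op, args in bindings:
        if op == "R.builtin.alloc_tensor":
            shape, dtype = args
            storage = f"storage_{counter}"
            counter += 1
            # Allocate backing storage for the tensor...
            out.append((storage, "R.memory.alloc_storage", (shape, dtype)))
            # ...then view it as a tensor at offset 0.
            out.append((var, "R.memory.alloc_tensor", (storage, 0, shape, dtype)))
        else:
            out.append((var, op, args))
    return out

before = [("x", "R.builtin.alloc_tensor", ((16,), "float32")), ("y", "add", ("x",))]
after = lower_alloc_tensor(before)
```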
Force-pushed from 1cf89ad to 253e013.
The `R.memory.alloc_storage` produced by `LowerAllocTensor` must be present in order to be appropriately deleted by `KillAfterLastUse`.
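The interaction with `KillAfterLastUse` can be sketched on the same toy binding representation (again illustrative names only, not the real pass): a kill marker is inserted immediately after the last binding that uses each storage variable, which is only possible if the `R.memory.alloc_storage` is explicit in the IR.

```python
# Toy sketch of kill insertion after last use (illustrative, not TVM's
# actual KillAfterLastUse pass). Each binding is (result_var, op, arg_vars).
def insert_kills(bindings):
    """Insert a kill_storage marker immediately after the last binding
    that defines or uses each storage variable."""
    storages = {var for var, op, _ in bindings if op == "alloc_storage"}
    # Index of the last definition or use of each storage variable.
    last_use = {}
    for i, (var, _, args) in enumerate(bindings):
        for s in storages:
            if s == var or s in args:
                last_use[s] = i
    out = []
    for i, binding in enumerate(bindings):
        out.append(binding)
        for s, j in last_use.items():
            if i == j:
                out.append((None, "kill_storage", (s,)))
    return out

demo = [
    ("storage", "alloc_storage", ()),
    ("tensor", "alloc_tensor", ("storage",)),
    ("result", "add", ("tensor",)),
]
lowered = insert_kills(demo)
```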