Conversation

@Lunderberg (Contributor):
The `StaticPlanBlockMemory` transform is provided a module that expresses all allocations with `R.builtin.alloc_tensor`, and produces a module that uses `R.memory.alloc_storage` and `R.memory.alloc_tensor` to express static allocations, while dynamic allocations continue to use `R.builtin.alloc_tensor`.

Prior to this commit, this mixed output was handled as part of `VMBuiltinLower`. This commit extracts the lowering of `R.builtin.alloc_tensor` to a new pass, `LowerAllocTensor`. This pass runs after `StaticPlanBlockMemory`, and replaces any remaining `R.builtin.alloc_tensor` with calls to `R.memory.alloc_storage` and `R.memory.alloc_tensor`.
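The rewrite described above can be sketched with a toy pass over a simplified call list. This is a hypothetical illustration, not TVM's actual implementation: the `Call` class, the `ITEMSIZE` table, and the attribute names are assumptions made for the sketch. It shows the two-step form the lowering produces: one storage allocation sized for the tensor, plus a tensor view at offset zero into that storage.

```python
# Toy sketch (NOT TVM's implementation) of the LowerAllocTensor rewrite:
# each remaining builtin alloc_tensor is split into an explicit storage
# allocation and a tensor view into that storage.
from dataclasses import dataclass, field
from typing import Dict, List

# Hypothetical dtype-to-bytes table for the sketch.
ITEMSIZE = {"float32": 4, "float16": 2, "int32": 4}

@dataclass
class Call:
    op: str
    attrs: Dict[str, object] = field(default_factory=dict)

def lower_alloc_tensor(prog: List[Call]) -> List[Call]:
    """Rewrite each builtin.alloc_tensor into storage + tensor calls."""
    out: List[Call] = []
    for call in prog:
        if call.op == "builtin.alloc_tensor":
            shape = call.attrs["shape"]
            dtype = call.attrs["dtype"]
            # Size the backing storage for the whole tensor.
            nbytes = ITEMSIZE[dtype]
            for dim in shape:
                nbytes *= dim
            out.append(Call("memory.alloc_storage", {"size": nbytes}))
            # The tensor becomes a view at offset zero into that storage.
            out.append(Call("memory.alloc_tensor",
                            {"shape": shape, "dtype": dtype, "offset": 0}))
        else:
            out.append(call)
    return out
```

In the real pipeline, `StaticPlanBlockMemory` has already emitted this two-call form for static allocations; `LowerAllocTensor` applies the same shape to whatever dynamic `R.builtin.alloc_tensor` calls remain.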

passes.append(relax.transform.LowerAllocTensor())

if tvm.transform.PassContext.current().config.get("relax.backend.use_cuda_graph", False):
    passes.append(relax.transform.RewriteCUDAGraph())
Contributor:
Did you look into what assumptions `relax.transform.RewriteCUDAGraph` makes about static vs. dynamic allocations, to ensure the expectations match before and after this change, given that `RewriteCUDAGraph` has some dependence on `alloc_tensor`?

@Lunderberg (Contributor, Author):

Thank you for the reminder. I had assumed that the tests in `tests/python/relax/test_transform_rewrite_cuda_graph.py` and `tests/python/relax/test_vm_cuda_graph.py` would be sufficient. On closer inspection, it turns out that the former provides the input to `RewriteCUDAGraph`, the latter provides the output from `RewriteCUDAGraph`, and neither tests the behavior of the pass as it exists within a lowering flow. After writing a quick end-to-end test, I found an issue that occurs within the CUDA graph rewriting pass.

@Lunderberg (Contributor, Author):

It appears to be a bug in `RewriteCUDAGraph`, which occurs when there is an `R.memory.alloc_storage` that is then used in a trivial var-to-var rebinding.
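The general hazard described above can be sketched generically. This is a hypothetical toy, not `RewriteCUDAGraph`'s actual code: the binding representation and both helper functions are assumptions made to illustrate how a pass that keys its bookkeeping on the variable originally bound to `alloc_storage` can miss uses that flow through a trivial `y = x` rebinding.

```python
# Toy sketch (hypothetical, not RewriteCUDAGraph's real code) of why a
# trivial var-to-var rebinding can trip up an allocation-tracking pass.
# Each binding is (dest_var, right_hand_side).
bindings = [
    ("x", ("call", "memory.alloc_storage")),
    ("y", ("var", "x")),   # trivial var-to-var rebinding: y aliases x
    ("z", ("use", "y")),   # a use of the storage, but via the alias
]

def storage_vars_naive(bindings):
    # Records only direct alloc_storage results; the alias "y" is missed,
    # so the use at "z" is not recognized as a use of the storage.
    return {v for v, rhs in bindings if rhs == ("call", "memory.alloc_storage")}

def storage_vars_with_aliases(bindings):
    # Also follows trivial rebindings, mapping aliases back to the storage.
    storages = set()
    for v, rhs in bindings:
        if rhs == ("call", "memory.alloc_storage"):
            storages.add(v)
        elif rhs[0] == "var" and rhs[1] in storages:
            storages.add(v)
    return storages
```

Under this sketch, the naive analysis reports only `x` as a storage variable, while the alias-aware one also reports `y`; a pass relying on the naive view would mishandle the rebound storage.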

@Lunderberg (Contributor, Author):

This bug ended up being trickier to track down than I had expected. It will be much simpler to solve after #15810 lands, since the `.kill_*` methods won't yet have been inserted. For now, I've re-ordered the passes so that `LowerAllocTensor` runs after `RewriteCUDAGraph`.

Contributor:

Awesome, many thanks for looking into this @Lunderberg; changing the ordering until the fix is in makes sense. Let's land this and #15810.

@Lunderberg force-pushed the unity_lower_alloc_tensor branch from 1cf89ad to 253e013 on September 28, 2023 13:56.
The `R.memory.alloc_storage` produced by `LowerAllocTensor` must be
present in order to be appropriately deleted by `KillAfterLastUse`.
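The commit message above explains why the explicit `R.memory.alloc_storage` calls must exist in the IR: a kill-insertion pass can only free what it can see. A toy sketch of that idea (hypothetical, not TVM's `KillAfterLastUse`; the statement format and `insert_kills` helper are assumptions) inserts a kill marker after the last use of each visible storage binding:

```python
# Toy sketch (NOT TVM's KillAfterLastUse) of inserting a kill after the
# last use of each storage binding. Statements are (dest, op, arg_vars).
# If the alloc_storage call were not explicit in the program, the pass
# would have nothing to key its liveness bookkeeping on.
def insert_kills(prog):
    # Storage variables are the results of explicit alloc_storage calls.
    storages = {dest for dest, op, _ in prog if op == "alloc_storage"}
    # Find the index of the last statement that uses each storage var.
    last_use = {}
    for i, (_, _, args) in enumerate(prog):
        for a in args:
            if a in storages:
                last_use[a] = i
    # Emit the program with a kill marker right after each last use.
    out = []
    for i, stmt in enumerate(prog):
        out.append(stmt)
        for v, idx in last_use.items():
            if idx == i:
                out.append(("", "kill", [v]))
    return out
```

For example, in a program `alloc_storage -> alloc_tensor(s) -> compute(t)`, the last use of the storage `s` is the `alloc_tensor` call, so the kill lands immediately after it, before the compute that only touches the tensor view.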
@csullivan merged commit 4a8a7b9 into apache:unity on Sep 29, 2023.
@Lunderberg deleted the unity_lower_alloc_tensor branch on September 30, 2023 13:27.