-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Relay][TE] Use Relay parameter name to generated TE tensor name #10516
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][TE] Use Relay parameter name to generated TE tensor name #10516
Conversation
c37278f to
6d7b337
Compare
|
Rebase onto main following #10535 |
564bf5e to
c5f19dc
Compare
|
Rebase onto main following merge conflict with #10577. |
|
Current CI failures due to checks for the name |
tkonolige
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a very helpful change. Thanks @Lunderberg!
68d2c5b to
9ac6636
Compare
|
@Lunderberg what's the status on this one? |
|
Currently there's one CI failure that I haven't had time to track down yet. The error message occurs during memory verification, that a buffer is undefined, but likely is caused by an earlier pass failing to update a buffer. |
9ac6636 to
f163bdd
Compare
|
Rebased onto main, since the failing test case was disabled in #10717. |
47e9550 to
30c4c8b
Compare
|
I let this one languish long enough that I no longer trust the CI checks to be up-to-date. Re-running CI. |
Previously, the TE placeholders representing relay function parameters were all named `"placeholder"`, which could be difficult to follow when debugging larger functions.
The tensor name "ethos-u", once passed through from relay to TE, resulted in invalid C++ codegen.
The tensor name is part of the hash used for these results, so the previous hashes are no longer valid.
8a70448 to
c7dc0ff
Compare
#10516 used the Relay parameter name when lowering to TE. However, this creates an issue when the parameter name is empty. This is legal in Relay, but results in errors during code generation. For example, this is the generated CUDA kernel for bias add: ``` extern "C" __global__ void __launch_bounds__(1024) fused_raf_op_tvm_add_kernel0( float* __restrict__ T_add, float* __restrict__ , /* Name is missing and it results in compile errors. */ float* __restrict__ _1) { T_add[((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] = ([((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] + _1[((((((int)blockIdx.x) * 16) + (((int)threadIdx.x) >> 6)) % 54) / 9)]); } ``` This PR adds "placeholder" back as a default to make sure no empty string will be passed when lowering to TE.
…che#10516) * [Relay][TE] Use Relay parameter name to generated TE tensor name Previously, the TE placeholders representing relay function parameters were all named `"placeholder"`, which could be difficult to follow when debugging larger functions.
apache#10516 used the Relay parameter name when lowering to TE. However, this creates an issue when the parameter name is empty. This is legal in Relay, but results in errors during code generation. For example, this is the generated CUDA kernel for bias add: ``` extern "C" __global__ void __launch_bounds__(1024) fused_raf_op_tvm_add_kernel0( float* __restrict__ T_add, float* __restrict__ , /* Name is missing and it results in compile errors. */ float* __restrict__ _1) { T_add[((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] = ([((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] + _1[((((((int)blockIdx.x) * 16) + (((int)threadIdx.x) >> 6)) % 54) / 9)]); } ``` This PR adds "placeholder" back as a default to make sure no empty string will be passed when lowering to TE.
Previously, the TE placeholders representing relay function parameters were all named
"placeholder", which could be difficult to follow when debugging larger functions.