Skip to content

Conversation

@comaniac
Copy link
Contributor

#10516 used the Relay parameter name when lowering to TE. However, this creates an issue when the parameter name is empty. This is legal in Relay, but results in errors during code generation. For example, this is the generated CUDA kernel for bias add:

extern "C" __global__ void __launch_bounds__(1024) fused_raf_op_tvm_add_kernel0(
    float* __restrict__ T_add,
    float* __restrict__ , /* Name is missing and it results in compile errors. */
    float* __restrict__ _1) {
    T_add[((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] = ([((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] + _1[((((((int)blockIdx.x) * 16) + (((int)threadIdx.x) >> 6)) % 54) / 9)]);
}

This PR adds "placeholder" back as a default to make sure no empty string will be passed when lowering to TE.

cc @Lunderberg @tkonolige

Copy link
Member

@junrushao junrushao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@junrushao junrushao merged commit 5ddd35c into main Sep 26, 2022
@comaniac comaniac deleted the fix_placeholder branch September 26, 2022 23:41
@Lunderberg
Copy link
Contributor

Thank you for the catch! Is there a test that should have caught this issue?

@junrushao
Copy link
Member

Agreed with @Lunderberg! Perhaps a regression test will be helpful (I don't know how hard it is to do so though)

xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
apache#10516 used the Relay parameter name when lowering to TE. However, this creates an issue when the parameter name is empty. This is legal in Relay, but results in errors during code generation. For example, this is the generated CUDA kernel for bias add:

```
extern "C" __global__ void __launch_bounds__(1024) fused_raf_op_tvm_add_kernel0(
    float* __restrict__ T_add,
    float* __restrict__ , /* Name is missing and it results in compile errors. */
    float* __restrict__ _1) {
    T_add[((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] = ([((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] + _1[((((((int)blockIdx.x) * 16) + (((int)threadIdx.x) >> 6)) % 54) / 9)]);
}
```

This PR adds "placeholder" back as a default to make sure no empty string will be passed when lowering to TE.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants