CUDA: Optimize PAD_REFLECT_1D #15957
Conversation
feat: add more test cases for PAD_REFLECT_1D
Thanks for the PR, really nice work! I was exploring a similar implementation. We both use a threadblock layout tied to ne0, but I think keeping a loop structure in this kernel can hide memory access latency, so I want to introduce a parameter UNROLL to control a grid-stride loop. In short: ne0 / UNROLL threads per line, each processing UNROLL elements in a single iteration; I think your approach implies UNROLL=1. Here are some benchmark results (device: 3060M, 6 GB):
Not a big change, but it might improve robustness across workloads. If you think it makes sense, take a look at my draft implementation:
Hmm, your code is really good: if the tensor shape is small, the no-loop version performs better; if the shape is large enough, your unrolled version is the best.
JohannesGaessler
left a comment
On some GPUs it's faster to pass float pointers and to calculate offsets like s01 = nb01 / sizeof(float) in host code.
You are right. If the tensor is small, I guess we need to activate more CUDA cores to increase the occupancy of the SMs. Maybe we could have a strategy that sets a different UNROLL value according to the tensor size, but it would also make the code more complex. For example:

```cpp
template<int UNROLL = 1>
static __global__ void pad_reflect_1d_kernel_f32(...);

void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    // same as your code
    if (TENSORSIZE <= LIMIT) {
        pad_reflect_1d_kernel_f32<1><<<grid, block, 0, stream>>>(...); // better for small tensors
    } else {
        pad_reflect_1d_kernel_f32<4><<<grid, block, 0, stream>>>(...); // better for large tensors
    }
}
```

I think the value of LIMIT is related to the GPU architecture; on my device it's about 32768 (i0 * i1 * i2 * i3), but I am not sure.
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
JohannesGaessler
left a comment
I forgot: try adding __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) to the kernel. This tells the compiler that the kernel will only ever be launched with 256 threads so the compiler can optimize it more aggressively.
I tried it. I did add it, but the performance stays the same.
We don't need the absolute fastest implementation of padding; most likely it takes up almost none of the actual runtime. I'm willing to maintain a single variant of the kernel unless it can be demonstrated that there is a use case where having two kernels makes a meaningful difference for end-to-end performance.
# Conflicts:
#	ggml/src/ggml-cuda/pad_reflect_1d.cu
Got it, thank you for the guidance. The final speed improvement for shape [512, 34, 2, 1] is +29.6%. Thanks for the help and hints.
* CUDA: Optimize PAD_REFLECT_1D

  feat: add more test cases for PAD_REFLECT_1D
* use fast_div to improve performance
* Apply suggestion from JohannesGaessler

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* Apply suggestion from JohannesGaessler

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* optimize
* use a concise expression to further speedup the cuda kernel

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
In the previous PR #14659, JohannesGaessler said (#14659 (comment)):
so I wrote a version without loops. It benefits the most for smaller tensor sizes, but in general it improves at any size as well.
By the way, I added more test cases for PAD_REFLECT_1D.
here is the benchmark summary:
Benchmark Results

Highlights: the largest gains are on the smaller shapes ([512,34,2,1], [3000,80,1,1], [3000,80,4,1]); the larger shapes ([3000,384,1,1], [3000,384,4,1]) also improve.

Overall Result: raw benchmark data (old kernel / new kernel tables) was attached as collapsed tables and is not preserved in this extract.