issue with the implementation of column_sum_reduce#804
issue with the implementation of column_sum_reduce#804RezaYazdaniAminabadi merged 2 commits intodeepspeedai:masterfrom zmxdream:zmx-patch-1
Conversation
hi, i take a look at the code of column_sum_reduce, i have 2 questions: 1. the goal of column_sum_reduce is to get the column sum of inp matrix with shape[rows, width] and the result shape should be [width],right ? It seems that the judgment condition of pos is not suitable 2. the implementation of cuda kernel based on the asumption that, the thread with same threadIdx.y will group into a thread_block_tile, the blockDim is (32,32), i read the nvidia document https://on-demand.gputechconf.com/gtc/2017/presentation/s7622-Kyrylo-perelygin-robust-and-scalable-cuda.pdf, THREAD BLOCK TILE is a subset of threads of a thread block, divided into tiles in row-major order. doesn't it mean thread with the same threadIdx.x will group into a thread_block_tile ? thanks !!!!
| if (threadIdx.x == 0) { | ||
| int pos = blockIdx.x * TILE_DIM + threadIdx.y; | ||
| if (pos < (rows * width)) out[pos] = sum; | ||
| if (pos < width) out[pos] = sum; |
There was a problem hiding this comment.
Thanks for fixing this! I would say it still was working when the hidden dimension was dividable by 32, however, it would have caused a memory leak for when the hidden is not dividable by 32!
There was a problem hiding this comment.
yes! thanks for your approval!!
Hi @zmx19951103 Thanks for fixing this bug. Regarding your second question, I think both x and y dimensions are assigned to different thread_block tiles, however, since this is a 2-dimensional tile, we are just using the the threadIx.y for saving the output after all is reduced across each tile (here ) whose got the same y index and x index changes from 0 to 31. So, what you are saying is also true, and this is also our assumption when reducing the elements in a row! |
hi, i take a look at the code of column_sum_reduce, i have 2 questions:
thanks !!!!