diff --git a/source/lib/src/cuda/tabulate.cu b/source/lib/src/cuda/tabulate.cu index 606b96537e..b71a989819 100644 --- a/source/lib/src/cuda/tabulate.cu +++ b/source/lib/src/cuda/tabulate.cu @@ -73,17 +73,13 @@ __global__ void tabulate_fusion_fifth_order_polynomial( const int nnei, const int last_layer_size) { - extern __shared__ int _data[]; const int block_idx = blockIdx.x; // nloc const int thread_idx = threadIdx.x; // last_layer_size FPTYPE ago = __shfl_sync(0xffffffff, em_x[block_idx * nnei + nnei - 1], 0); bool unloop = false; int breakpoint = nnei - 1; - FPTYPE * iteratorC = (FPTYPE*) &_data[0]; - for (int kk = 0; kk < MTILE; kk++) - iteratorC[kk * last_layer_size + thread_idx] = 0.f; - __syncthreads(); + FPTYPE sum[MTILE] = {0.f}; for (int ii = 0; ii < nnei; ii++) { FPTYPE var[6]; FPTYPE xx = em_x[block_idx * nnei + ii]; @@ -102,12 +98,12 @@ __global__ void tabulate_fusion_fifth_order_polynomial( FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx; for (int kk = 0; kk < MTILE; kk++) { - iteratorC[kk * last_layer_size + thread_idx] += (nnei - breakpoint) * em[block_idx * nnei * MTILE + ii * MTILE + kk] * res; + sum[kk] += (nnei - breakpoint) * em[block_idx * nnei * MTILE + ii * MTILE + kk] * res; } if (unloop) break; } for (int ii = 0; ii < MTILE; ii++) { - out[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx] = iteratorC[ii * last_layer_size + thread_idx]; + out[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx] = sum[ii]; } } @@ -133,8 +129,8 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial( extern __shared__ int _data[]; const int block_idx = blockIdx.x; // nloc const int thread_idx = threadIdx.x; // KTILE * WARP_SIZE, usally 128 here~ - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0); + int lane_idx = threadIdx.x % WARP_SIZE; int breakpoint = nnei - 1; bool unloop = false; FPTYPE * iteratorA = (FPTYPE *)&_data[0]; // dy @@ -145,16 +141,16 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial( } __syncthreads(); FPTYPE ago = __shfl_sync(0xffffffff, em_x[block_idx * nnei + nnei - 1], 0); - for (int ii = 0; ii < nnei; ii += KTILE) { - FPTYPE xx = em_x[block_idx * nnei + ii + warp_idx]; + for (int ii = warp_idx; ii < nnei; ii += KTILE) { + FPTYPE xx = em_x[block_idx * nnei + ii]; if (ago == xx) { unloop = true; - breakpoint = ii + warp_idx; + breakpoint = ii; } int table_idx = 0; locate_xx(xx, table_idx, lower, upper, max, stride0, stride1); - FPTYPE sum[KTILE] = {0.f}; + FPTYPE sum[MTILE] = {0.f}; FPTYPE Csub = 0.f; for (int jj = lane_idx; jj < last_layer_size; jj += WARP_SIZE) { FPTYPE var[6]; @@ -167,25 +163,25 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial( var[5] = table[table_idx * last_layer_size * 6 + 6 * jj + 5]; FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx; - for (int kk = 0; kk < KTILE; kk++) { + for (int kk = 0; kk < MTILE; kk++) { sum[kk] += (nnei - breakpoint) * iteratorA[kk * last_layer_size + jj] * res; } - res = em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 0] * iteratorA[0 * last_layer_size + jj]; - res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 1] * iteratorA[1 * last_layer_size + jj]; - res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 2] * iteratorA[2 * last_layer_size + jj]; - res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 3] * iteratorA[3 * last_layer_size + jj]; + res = em[block_idx * nnei * MTILE + ii * 4 + 0] * iteratorA[0 * last_layer_size + jj]; + res += em[block_idx * nnei * MTILE + ii * 4 + 1] * iteratorA[1 * last_layer_size + jj]; + res += em[block_idx * nnei * MTILE + ii * 4 + 2] * iteratorA[2 * last_layer_size + jj]; + res += em[block_idx * nnei * MTILE + ii * 4 + 3] * iteratorA[3 * last_layer_size + jj]; Csub += (nnei - breakpoint) * (var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx) * res; } __syncwarp(); - for (int kk = 0; kk < KTILE; kk++) { + for (int kk = 0; kk < MTILE; kk++) { warp_reduce(sum[kk]); } warp_reduce(Csub); if (lane_idx == 0) { - for (int kk = 0; kk < KTILE; kk++) { - dy_dem[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + kk] = sum[kk]; + for (int kk = 0; kk < MTILE; kk++) { + dy_dem[block_idx * nnei * MTILE + ii * 4 + kk] = sum[kk]; } - dy_dem_x[block_idx * nnei + ii + warp_idx] = Csub; + dy_dem_x[block_idx * nnei + ii] = Csub; } if (unloop) break; } @@ -204,7 +200,7 @@ void tabulate_fusion_gpu_cuda( const int last_layer_size) { if (nloc <= 0) {return;} - tabulate_fusion_fifth_order_polynomial <<>>( + tabulate_fusion_fifth_order_polynomial <<>>( out, table, em_x, em, table_info[0], table_info[1], table_info[2], table_info[3], table_info[4], nnei, last_layer_size); }