Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 19 additions & 23 deletions source/lib/src/cuda/tabulate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,13 @@ __global__ void tabulate_fusion_fifth_order_polynomial(
const int nnei,
const int last_layer_size)
{
extern __shared__ int _data[];
const int block_idx = blockIdx.x; // nloc
const int thread_idx = threadIdx.x; // last_layer_size
FPTYPE ago = __shfl_sync(0xffffffff, em_x[block_idx * nnei + nnei - 1], 0);
bool unloop = false;
int breakpoint = nnei - 1;
FPTYPE * iteratorC = (FPTYPE*) &_data[0];
for (int kk = 0; kk < MTILE; kk++)
iteratorC[kk * last_layer_size + thread_idx] = 0.f;
__syncthreads();

FPTYPE sum[MTILE] = {0.f};
for (int ii = 0; ii < nnei; ii++) {
FPTYPE var[6];
FPTYPE xx = em_x[block_idx * nnei + ii];
Expand All @@ -102,12 +98,12 @@ __global__ void tabulate_fusion_fifth_order_polynomial(
FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx;

for (int kk = 0; kk < MTILE; kk++) {
iteratorC[kk * last_layer_size + thread_idx] += (nnei - breakpoint) * em[block_idx * nnei * MTILE + ii * MTILE + kk] * res;
sum[kk] += (nnei - breakpoint) * em[block_idx * nnei * MTILE + ii * MTILE + kk] * res;
}
if (unloop) break;
}
for (int ii = 0; ii < MTILE; ii++) {
out[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx] = iteratorC[ii * last_layer_size + thread_idx];
out[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx] = sum[ii];
}
}

Expand All @@ -133,8 +129,8 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial(
extern __shared__ int _data[];
const int block_idx = blockIdx.x; // nloc
const int thread_idx = threadIdx.x; // KTILE * WARP_SIZE, usally 128 here~
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);
int lane_idx = threadIdx.x % WARP_SIZE;
int breakpoint = nnei - 1;
bool unloop = false;
FPTYPE * iteratorA = (FPTYPE *)&_data[0]; // dy
Expand All @@ -145,16 +141,16 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial(
}
__syncthreads();
FPTYPE ago = __shfl_sync(0xffffffff, em_x[block_idx * nnei + nnei - 1], 0);
for (int ii = 0; ii < nnei; ii += KTILE) {
FPTYPE xx = em_x[block_idx * nnei + ii + warp_idx];
for (int ii = warp_idx; ii < nnei; ii += KTILE) {
FPTYPE xx = em_x[block_idx * nnei + ii];
if (ago == xx) {
unloop = true;
breakpoint = ii + warp_idx;
breakpoint = ii;
}

int table_idx = 0;
locate_xx(xx, table_idx, lower, upper, max, stride0, stride1);
FPTYPE sum[KTILE] = {0.f};
FPTYPE sum[MTILE] = {0.f};
FPTYPE Csub = 0.f;
for (int jj = lane_idx; jj < last_layer_size; jj += WARP_SIZE) {
FPTYPE var[6];
Expand All @@ -167,25 +163,25 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial(
var[5] = table[table_idx * last_layer_size * 6 + 6 * jj + 5];
FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx;

for (int kk = 0; kk < KTILE; kk++) {
for (int kk = 0; kk < MTILE; kk++) {
sum[kk] += (nnei - breakpoint) * iteratorA[kk * last_layer_size + jj] * res;
}
res = em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 0] * iteratorA[0 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 1] * iteratorA[1 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 2] * iteratorA[2 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 3] * iteratorA[3 * last_layer_size + jj];
res = em[block_idx * nnei * MTILE + ii * 4 + 0] * iteratorA[0 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + ii * 4 + 1] * iteratorA[1 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + ii * 4 + 2] * iteratorA[2 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + ii * 4 + 3] * iteratorA[3 * last_layer_size + jj];
Csub += (nnei - breakpoint) * (var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx) * res;
}
__syncwarp();
for (int kk = 0; kk < KTILE; kk++) {
for (int kk = 0; kk < MTILE; kk++) {
warp_reduce(sum[kk]);
}
warp_reduce(Csub);
if (lane_idx == 0) {
for (int kk = 0; kk < KTILE; kk++) {
dy_dem[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + kk] = sum[kk];
for (int kk = 0; kk < MTILE; kk++) {
dy_dem[block_idx * nnei * MTILE + ii * 4 + kk] = sum[kk];
}
dy_dem_x[block_idx * nnei + ii + warp_idx] = Csub;
dy_dem_x[block_idx * nnei + ii] = Csub;
}
if (unloop) break;
}
Expand All @@ -204,7 +200,7 @@ void tabulate_fusion_gpu_cuda(
const int last_layer_size)
{
if (nloc <= 0) {return;}
tabulate_fusion_fifth_order_polynomial<FPTYPE, MM, KK> <<<nloc, last_layer_size, sizeof(FPTYPE) * MM * last_layer_size>>>(
tabulate_fusion_fifth_order_polynomial<FPTYPE, MM, KK> <<<nloc, last_layer_size>>>(
out,
table, em_x, em, table_info[0], table_info[1], table_info[2], table_info[3], table_info[4], nnei, last_layer_size);
}
Expand Down