Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 41 additions & 18 deletions source/lib/src/cuda/prod_env_mat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,29 @@ __global__ void format_nlist_fill_a(
}
}

template<typename FPTYPE>
__global__ void fill_nei_iter(
int * nei_iter_dev,
const FPTYPE * key,
const int nloc,
const int max_nbor_size,
const int sec_size)
{
int row = blockIdx.x;
int col = blockIdx.y * blockDim.x + threadIdx.x;
const FPTYPE * key_out = key + nloc * max_nbor_size + row * max_nbor_size;
int nei_type_cur = -1, nbor_idx_cur = 0;
int nei_type_pre = -1, nbor_idx_pre = 0;
if (col < max_nbor_size && key_out[col] != key_out[max_nbor_size - 1]){
if (col >= 1)
decoding_nbor_info(nei_type_pre, nbor_idx_pre, key_out[col - 1]);
decoding_nbor_info(nei_type_cur, nbor_idx_cur, key_out[col]);
}
if (nei_type_cur != nei_type_pre){
nei_iter_dev[row * sec_size + nei_type_cur] = col;
}
}

template<typename FPTYPE>
__global__ void format_nlist_fill_b(
int * nlist,
Expand All @@ -155,23 +178,19 @@ __global__ void format_nlist_fill_b(
int * nei_iter_dev,
const int max_nbor_size)
{
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
if(idx >= nloc) {
return;
}

int * row_nlist = nlist + idx * nlist_size;
int * nei_iter = nei_iter_dev + idx * sec_size;
FPTYPE * key_out = key + nloc * max_nbor_size + idx * max_nbor_size;
for (int ii = 0; ii < sec_size; ii++) {
nei_iter[ii] = sec[ii];
}

int nei_type = 0, nbor_idx = 0;
for (unsigned int kk = 0; key_out[kk] != key_out[max_nbor_size - 1]; kk++) {
decoding_nbor_info(nei_type, nbor_idx, key_out[kk]);
if (nei_iter[nei_type] < sec[nei_type + 1]) {
row_nlist[nei_iter[nei_type]++] = nbor_idx;
int row = blockIdx.x;
int col = blockIdx.y * blockDim.x + threadIdx.x;
int * nei_iter = nei_iter_dev + row * sec_size;
FPTYPE * key_out = key + nloc * max_nbor_size + row * max_nbor_size;
int * row_nlist = nlist + row * nlist_size;
if (col < max_nbor_size){
if (key_out[col] != key_out[max_nbor_size - 1]){
int nei_type = 0, nbor_idx = 0;
decoding_nbor_info(nei_type, nbor_idx, key_out[col]);
int out_indx = col - nei_iter[nei_type] + sec[nei_type];
if (out_indx < sec[nei_type + 1]){
row_nlist[out_indx] = nbor_idx;
}
}
}
}
Expand Down Expand Up @@ -473,7 +492,11 @@ void format_nbor_list_gpu_cuda(
coord, type, gpu_inlist, nloc, rcut, i_idx);
}

format_nlist_fill_b<<<nblock, LEN>>> (
fill_nei_iter <<<dim3(nloc, (max_nbor_size + LEN - 1) / LEN) , LEN>>> (
nei_iter,
key, nloc, max_nbor_size, sec.size());

format_nlist_fill_b <<<dim3(nloc, (max_nbor_size + LEN - 1) / LEN), LEN>>> (
nlist,
nnei, nloc, key, sec_dev, sec.size(), nei_iter, max_nbor_size);
}
Expand Down