diff --git a/source/op/cuda/descrpt_se_a.cu b/source/op/cuda/descrpt_se_a.cu index 8893e8a00a..8b6b3ee575 100644 --- a/source/op/cuda/descrpt_se_a.cu +++ b/source/op/cuda/descrpt_se_a.cu @@ -208,13 +208,15 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript, const VALUETYPE* coord, const VALUETYPE rmin, const VALUETYPE rmax, - compute_t* sel_a_diff_dev) + compute_t* sel_a_diff_dev, + const int sec_a_size) { // <<>> - const unsigned int idx = blockIdx.x; - const unsigned int idy = threadIdx.x; + const unsigned int idx = blockIdx.y; + const unsigned int idy = blockIdx.x * blockDim.x + threadIdx.x; const int idx_deriv = idy * 4 * 3; // 4 components time 3 directions const int idx_value = idy * 4; // 4 components + if (idy >= sec_a_size) {return;} // else {return;} VALUETYPE * row_descript = descript + idx * ndescrpt; @@ -355,7 +357,9 @@ void DescrptSeALauncher(const VALUETYPE* coord, ); } - compute_descriptor_se_a<<>> ( + const int nblock_ = (sec_a.back() + LEN -1) / LEN; + dim3 block_grid(nblock_, nloc); + compute_descriptor_se_a<<>> ( descript, ndescrpt, descript_deriv, @@ -370,7 +374,8 @@ void DescrptSeALauncher(const VALUETYPE* coord, coord, rcut_r_smth, rcut_r, - sel_a_diff + sel_a_diff, + sec_a.back() ); //// // res = cudaFree(sec_a_dev); cudaErrcheck(res); diff --git a/source/op/cuda/descrpt_se_r.cu b/source/op/cuda/descrpt_se_r.cu index cc7dd4e904..2a4a126166 100644 --- a/source/op/cuda/descrpt_se_r.cu +++ b/source/op/cuda/descrpt_se_r.cu @@ -209,14 +209,16 @@ __global__ void compute_descriptor_se_r (VALUETYPE* descript, const VALUETYPE* coord, const VALUETYPE rmin, const VALUETYPE rmax, - compute_t* sel_diff_dev) + compute_t* sel_diff_dev, + const int sec_size) { // <<>> - const unsigned int idx = blockIdx.x; - const unsigned int idy = threadIdx.x; + const unsigned int idx = blockIdx.y; + const unsigned int idy = blockIdx.x * blockDim.x + threadIdx.x; const int idx_deriv = idy * 3; // 1 components time 3 directions const int idx_value = idy; // 1 components - + if (idy >= sec_size) {return;} + // else {return;} VALUETYPE * row_descript = descript + idx * ndescrpt; VALUETYPE * row_descript_deriv = descript_deriv + idx * descript_deriv_size; @@ -324,7 +326,9 @@ void DescrptSeRLauncher(const VALUETYPE* coord, nei_iter ); } - compute_descriptor_se_r<<>> ( + const int nblock_ = (sec.back() + LEN -1) / LEN; + dim3 block_grid(nblock_, nloc); + compute_descriptor_se_r<<>> ( descript, ndescrpt, descript_deriv, @@ -339,6 +343,7 @@ void DescrptSeRLauncher(const VALUETYPE* coord, coord, rcut_smth, rcut, - sel_diff + sel_diff, + sec.back() ); } \ No newline at end of file