From 143afce2e05ba90382ed67ece785790938969d02 Mon Sep 17 00:00:00 2001 From: Lu Date: Mon, 14 Sep 2020 13:25:41 +0800 Subject: [PATCH] fix bug of error compilation when using float precision --- source/op/cuda/descrpt_se_a.cu | 41 +++++++++++++++++----------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/source/op/cuda/descrpt_se_a.cu b/source/op/cuda/descrpt_se_a.cu index 1636df8ff5..1dd255fff2 100644 --- a/source/op/cuda/descrpt_se_a.cu +++ b/source/op/cuda/descrpt_se_a.cu @@ -24,8 +24,6 @@ limitations under the License. typedef float VALUETYPE; #endif -typedef double compute_t; - typedef unsigned long long int_64; #define cudaErrcheck(res) { cudaAssert((res), __FILE__, __LINE__); } @@ -77,19 +75,20 @@ __device__ inline T dev_dot(T * arr1, T * arr2) { return arr1[0] * arr2[0] + arr1[1] * arr2[1] + arr1[2] * arr2[2]; } -__device__ inline void spline5_switch(compute_t & vv, - compute_t & dd, - compute_t & xx, - const compute_t & rmin, - const compute_t & rmax) +template <typename FPTYPE> +__device__ inline void spline5_switch(FPTYPE & vv, + FPTYPE & dd, + FPTYPE & xx, + const float & rmin, + const float & rmax) { if (xx < rmin) { dd = 0; vv = 1; } else if (xx < rmax) { - compute_t uu = (xx - rmin) / (rmax - rmin) ; - compute_t du = 1. - FPTYPE uu = (xx - rmin) / (rmax - rmin) ; + FPTYPE du = 1. 
/ (rmax - rmin) ; vv = uu*uu*uu * (-6 * uu*uu + 15 * uu - 10) + 1; dd = ( 3 * uu*uu * (-6 * uu*uu + 15 * uu - 10) + uu*uu*uu * (-12 * uu + 15) ) * du; } @@ -133,12 +132,12 @@ __global__ void format_nlist_fill_a_se_a(const VALUETYPE * coord, int_64 * key_in = key + idx * MAGIC_NUMBER; - compute_t diff[3]; + VALUETYPE diff[3]; const int & j_idx = nei_idx[idy]; for (int dd = 0; dd < 3; dd++) { diff[dd] = coord[j_idx * 3 + dd] - coord[idx * 3 + dd]; } - compute_t rr = sqrt(dev_dot(diff, diff)); + VALUETYPE rr = sqrt(dev_dot(diff, diff)); if (rr <= rcut) { key_in[idy] = type[j_idx] * 1E15+ (int_64)(rr * 1.0E13) / 100000 * 100000 + j_idx; } @@ -192,8 +191,8 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript, int* nlist, const int nlist_size, const VALUETYPE* coord, - const VALUETYPE rmin, - const VALUETYPE rmax, + const float rmin, + const float rmax, const int sec_a_size) { // <<>> @@ -214,14 +213,14 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript, for (int kk = 0; kk < 3; kk++) { row_rij[idy * 3 + kk] = coord[j_idx * 3 + kk] - coord[idx * 3 + kk]; } - const compute_t * rr = &row_rij[idy * 3 + 0]; - compute_t nr2 = dev_dot(rr, rr); - compute_t inr = 1./sqrt(nr2); - compute_t nr = nr2 * inr; - compute_t inr2 = inr * inr; - compute_t inr4 = inr2 * inr2; - compute_t inr3 = inr4 * nr; - compute_t sw, dsw; + const VALUETYPE * rr = &row_rij[idy * 3 + 0]; + VALUETYPE nr2 = dev_dot(rr, rr); + VALUETYPE inr = 1./sqrt(nr2); + VALUETYPE nr = nr2 * inr; + VALUETYPE inr2 = inr * inr; + VALUETYPE inr4 = inr2 * inr2; + VALUETYPE inr3 = inr4 * nr; + VALUETYPE sw, dsw; spline5_switch(sw, dsw, nr, rmin, rmax); row_descript[idx_value + 0] = (1./nr) ;//* sw; row_descript[idx_value + 1] = (rr[0] / nr2) ;//* sw;