diff --git a/source/lib/src/cuda/prod_env_mat.cu b/source/lib/src/cuda/prod_env_mat.cu index 85941b9da0..93a2b6a787 100644 --- a/source/lib/src/cuda/prod_env_mat.cu +++ b/source/lib/src/cuda/prod_env_mat.cu @@ -7,6 +7,8 @@ __device__ inline double _sqrt(double x) {return sqrt(x);} __device__ inline float _sqrt(float x) {return sqrtf(x);} +__device__ inline double _rsqrt(double x) {return rsqrt(x);} +__device__ inline float _rsqrt(float x) {return rsqrtf(x);} // common part of prod_env_mat template < @@ -408,7 +410,7 @@ __global__ void compute_env_mat_a( } // const FPTYPE * rr = &row_rij[ii * 3]; FPTYPE nr2 = dev_dot(rr, rr); - FPTYPE inr = (FPTYPE)1./_sqrt(nr2); + FPTYPE inr = _rsqrt(nr2); FPTYPE nr = nr2 * inr; FPTYPE inr2 = inr * inr; FPTYPE inr4 = inr2 * inr2; @@ -494,7 +496,7 @@ __global__ void compute_env_mat_r( } // const FPTYPE * rr = &row_rij[ii * 3]; FPTYPE nr2 = dev_dot(rr, rr); - FPTYPE inr = (FPTYPE)1./_sqrt(nr2); + FPTYPE inr = _rsqrt(nr2); FPTYPE nr = nr2 * inr; FPTYPE inr2 = inr * inr; FPTYPE inr4 = inr2 * inr2; diff --git a/source/lib/src/rocm/prod_env_mat.hip.cu b/source/lib/src/rocm/prod_env_mat.hip.cu index 6a437bd3e0..506a844a04 100644 --- a/source/lib/src/rocm/prod_env_mat.hip.cu +++ b/source/lib/src/rocm/prod_env_mat.hip.cu @@ -5,6 +5,8 @@ __device__ inline double _sqrt(double x) {return sqrt(x);} __device__ inline float _sqrt(float x) {return sqrtf(x);} +__device__ inline double _rsqrt(double x) {return rsqrt(x);} +__device__ inline float _rsqrt(float x) {return rsqrtf(x);} // common part of prod_env_mat template < @@ -406,7 +408,7 @@ __global__ void compute_env_mat_a( } // const FPTYPE * rr = &row_rij[ii * 3]; FPTYPE nr2 = dev_dot(rr, rr); - FPTYPE inr = (FPTYPE)1./_sqrt(nr2); + FPTYPE inr = _rsqrt(nr2); FPTYPE nr = nr2 * inr; FPTYPE inr2 = inr * inr; FPTYPE inr4 = inr2 * inr2; @@ -492,7 +494,7 @@ __global__ void compute_env_mat_r( } // const FPTYPE * rr = &row_rij[ii * 3]; FPTYPE nr2 = dev_dot(rr, rr); - FPTYPE inr = (FPTYPE)1./_sqrt(nr2); + FPTYPE inr = _rsqrt(nr2); FPTYPE nr = nr2 * inr; FPTYPE inr2 = inr * inr; FPTYPE inr4 = inr2 * inr2;