Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 20 additions & 21 deletions source/op/cuda/descrpt_se_a.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ limitations under the License.
typedef float VALUETYPE;
#endif

typedef double compute_t;

typedef unsigned long long int_64;

#define cudaErrcheck(res) { cudaAssert((res), __FILE__, __LINE__); }
Expand Down Expand Up @@ -77,19 +75,20 @@ __device__ inline T dev_dot(T * arr1, T * arr2) {
return arr1[0] * arr2[0] + arr1[1] * arr2[1] + arr1[2] * arr2[2];
}

__device__ inline void spline5_switch(compute_t & vv,
compute_t & dd,
compute_t & xx,
const compute_t & rmin,
const compute_t & rmax)
template<typename FPTYPE>
__device__ inline void spline5_switch(FPTYPE & vv,
FPTYPE & dd,
FPTYPE & xx,
const float & rmin,
const float & rmax)
{
if (xx < rmin) {
dd = 0;
vv = 1;
}
else if (xx < rmax) {
compute_t uu = (xx - rmin) / (rmax - rmin) ;
compute_t du = 1. / (rmax - rmin) ;
FPTYPE uu = (xx - rmin) / (rmax - rmin) ;
FPTYPE du = 1. / (rmax - rmin) ;
vv = uu*uu*uu * (-6 * uu*uu + 15 * uu - 10) + 1;
dd = ( 3 * uu*uu * (-6 * uu*uu + 15 * uu - 10) + uu*uu*uu * (-12 * uu + 15) ) * du;
}
Expand Down Expand Up @@ -133,12 +132,12 @@ __global__ void format_nlist_fill_a_se_a(const VALUETYPE * coord,

int_64 * key_in = key + idx * MAGIC_NUMBER;

compute_t diff[3];
VALUETYPE diff[3];
const int & j_idx = nei_idx[idy];
for (int dd = 0; dd < 3; dd++) {
diff[dd] = coord[j_idx * 3 + dd] - coord[idx * 3 + dd];
}
compute_t rr = sqrt(dev_dot(diff, diff));
VALUETYPE rr = sqrt(dev_dot(diff, diff));
if (rr <= rcut) {
key_in[idy] = type[j_idx] * 1E15+ (int_64)(rr * 1.0E13) / 100000 * 100000 + j_idx;
}
Expand Down Expand Up @@ -192,8 +191,8 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript,
int* nlist,
const int nlist_size,
const VALUETYPE* coord,
const VALUETYPE rmin,
const VALUETYPE rmax,
const float rmin,
const float rmax,
const int sec_a_size)
{
// <<<nloc, sec_a.back()>>>
Expand All @@ -214,14 +213,14 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript,
for (int kk = 0; kk < 3; kk++) {
row_rij[idy * 3 + kk] = coord[j_idx * 3 + kk] - coord[idx * 3 + kk];
}
const compute_t * rr = &row_rij[idy * 3 + 0];
compute_t nr2 = dev_dot(rr, rr);
compute_t inr = 1./sqrt(nr2);
compute_t nr = nr2 * inr;
compute_t inr2 = inr * inr;
compute_t inr4 = inr2 * inr2;
compute_t inr3 = inr4 * nr;
compute_t sw, dsw;
const VALUETYPE * rr = &row_rij[idy * 3 + 0];
VALUETYPE nr2 = dev_dot(rr, rr);
VALUETYPE inr = 1./sqrt(nr2);
VALUETYPE nr = nr2 * inr;
VALUETYPE inr2 = inr * inr;
VALUETYPE inr4 = inr2 * inr2;
VALUETYPE inr3 = inr4 * nr;
VALUETYPE sw, dsw;
spline5_switch(sw, dsw, nr, rmin, rmax);
row_descript[idx_value + 0] = (1./nr) ;//* sw;
row_descript[idx_value + 1] = (rr[0] / nr2) ;//* sw;
Expand Down