From e4e1ac7deb117b871fe392a919e10d6b9ffbd2bd Mon Sep 17 00:00:00 2001 From: Lu Date: Mon, 13 Apr 2020 15:03:08 +0800 Subject: [PATCH 1/2] add customized op for gelu activation function Both CPU and GPU version are supported! --- source/op/CMakeLists.txt | 21 +++- source/op/_gelu.py | 15 +++ source/op/cuda/CMakeLists.txt | 11 ++- source/op/cuda/gelu.cu | 77 +++++++++++++++ source/op/gelu.cc | 176 ++++++++++++++++++++++++++++++++++ source/op/gelu_gpu.cc | 171 +++++++++++++++++++++++++++++++++ source/train/Trainer.py | 1 + source/train/common.py | 26 ++--- 8 files changed, 479 insertions(+), 19 deletions(-) create mode 100644 source/op/_gelu.py create mode 100644 source/op/cuda/gelu.cu create mode 100644 source/op/gelu.cc create mode 100644 source/op/gelu_gpu.cc diff --git a/source/op/CMakeLists.txt b/source/op/CMakeLists.txt index 89416056e1..d0dc200236 100644 --- a/source/op/CMakeLists.txt +++ b/source/op/CMakeLists.txt @@ -3,8 +3,9 @@ set(OP_LIB ${PROJECT_SOURCE_DIR}/lib/src/SimulationRegion.cpp ${PROJECT_SOURCE_DIR}/lib/src/NeighborList.cpp) set (OP_CXX_FLAG -D_GLIBCXX_USE_CXX11_ABI=${OP_CXX_ABI} ) -file(GLOB OP_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a.cc descrpt_se_r.cc tab_inter.cc prod_force_se_a.cc prod_virial_se_a.cc prod_force_se_r.cc prod_virial_se_r.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc) -file(GLOB OP_CUDA_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a_gpu.cc descrpt_se_r_gpu.cc tab_inter.cc prod_force_se_a_gpu.cc prod_virial_se_a_gpu.cc prod_force_se_r_gpu.cc prod_virial_se_r_gpu.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ) +file(GLOB OP_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a.cc descrpt_se_r.cc tab_inter.cc prod_force_se_a.cc prod_virial_se_a.cc prod_force_se_r.cc prod_virial_se_r.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc gelu.cc) +file(GLOB OP_PY_CUDA_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a.cc descrpt_se_r.cc tab_inter.cc 
prod_force_se_a.cc prod_virial_se_a.cc prod_force_se_r.cc prod_virial_se_r.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc gelu_gpu.cc) +file(GLOB OP_CUDA_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a_gpu.cc descrpt_se_r_gpu.cc tab_inter.cc prod_force_se_a_gpu.cc prod_virial_se_a_gpu.cc prod_force_se_r_gpu.cc prod_virial_se_r_gpu.cc soft_min.cc soft_min_force.cc soft_min_virial.cc gelu_gpu.cc) file(GLOB OP_GRADS_SRC prod_force_grad.cc prod_force_se_a_grad.cc prod_force_se_r_grad.cc prod_virial_grad.cc prod_virial_se_a_grad.cc prod_virial_se_r_grad.cc soft_min_force_grad.cc soft_min_virial_grad.cc ) file(GLOB OP_PY *.py) @@ -23,8 +24,20 @@ if (BUILD_CPP_IF) endif (BUILD_CPP_IF) if (BUILD_PY_IF) - add_library(op_abi SHARED ${OP_SRC} ${OP_LIB}) - add_library(op_grads SHARED ${OP_GRADS_SRC}) + if (USE_CUDA_TOOLKIT) + add_library(op_abi SHARED ${OP_PY_CUDA_SRC} ${OP_LIB}) + add_library(op_grads SHARED ${OP_GRADS_SRC}) + add_subdirectory(cuda) + find_package(CUDA REQUIRED) + include_directories(${CUDA_INCLUDE_DIRS}) + set (EXTRA_LIBS ${EXTRA_LIBS} deepmd_op_cuda) + target_link_libraries (op_abi ${EXTRA_LIBS}) + target_link_libraries (op_grads ${EXTRA_LIBS}) + message(STATUS ${TensorFlowFramework_LIBRARY}) + else (USE_CUDA_TOOLKIT) + add_library(op_abi SHARED ${OP_SRC} ${OP_LIB}) + add_library(op_grads SHARED ${OP_GRADS_SRC}) + endif(USE_CUDA_TOOLKIT) target_link_libraries( op_abi ${TensorFlowFramework_LIBRARY} ) diff --git a/source/op/_gelu.py b/source/op/_gelu.py new file mode 100644 index 0000000000..9af8d3cbb0 --- /dev/null +++ b/source/op/_gelu.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +""" +First-order derivatives and second-order derivatives for gelu function. 
+""" + +from tensorflow.python.framework import ops +from deepmd.env import op_module + +@ops.RegisterGradient("Gelu") +def gelu_cc (op, dy) : + return op_module.gelu_grad(dy, op.inputs[0]) + +@ops.RegisterGradient("GeluGrad") +def gelu_grad_cc (op, dy) : + return [op_module.gelu_grad(dy, op.inputs[1]), op_module.gelu_grad_grad(dy, op.inputs[0], op.inputs[1])] # grads w.r.t. (dy, x): gelu'(x)*g and gelu''(x)*dy*g, where g is the incoming grad diff --git a/source/op/cuda/CMakeLists.txt b/source/op/cuda/CMakeLists.txt index 25f796500b..d3edc6e98e 100644 --- a/source/op/cuda/CMakeLists.txt +++ b/source/op/cuda/CMakeLists.txt @@ -80,9 +80,14 @@ else () endif() set (SOURCE_FILES - descrpt_se_a.cu descrpt_se_r.cu prod_force_se_a.cu prod_force_se_r.cu prod_virial_se_a.cu prod_virial_se_r.cu + descrpt_se_a.cu descrpt_se_r.cu prod_force_se_a.cu prod_force_se_r.cu prod_virial_se_a.cu prod_virial_se_r.cu gelu.cu ) -cuda_add_library(deepmd_op_cuda SHARED ${SOURCE_FILES}) +cuda_add_library(deepmd_op_cuda STATIC ${SOURCE_FILES}) -install(TARGETS deepmd_op_cuda DESTINATION lib/) +if (BUILD_CPP_IF) + install(TARGETS deepmd_op_cuda DESTINATION lib/) +endif (BUILD_CPP_IF) +if (BUILD_PY_IF) + install(TARGETS deepmd_op_cuda DESTINATION deepmd/) +endif (BUILD_PY_IF) diff --git a/source/op/cuda/gelu.cu b/source/op/cuda/gelu.cu new file mode 100644 index 0000000000..99b7b1aed4 --- /dev/null +++ b/source/op/cuda/gelu.cu @@ -0,0 +1,77 @@ +#include +#include + +#define SQRT_2_PI 0.7978845608028654 + +template +__global__ void gelu(const T * in, T * out, int const size) { + int const idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) {return;} + + out[idx] = in[idx] * 0.5 * (1.0 + tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx]))); +} + +template +__global__ void gelu_grad(const T * dy, const T * in, T * out, int const size) { + int const idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) {return;} + + // out[idx] = in[idx] * 0.5 * (1.0 + tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx]))); + T const var1 = tanh(SQRT_2_PI * (in[idx] + 0.044715 * 
in[idx] * in[idx] *in[idx])); + out[idx] = dy[idx] * (0.5 * SQRT_2_PI * in[idx] * (1 - var1 * var1) * (0.134145 * in[idx] * in[idx] + 1) + 0.5 * var1 + 0.5); +} + +template +__global__ void gelu_grad_grad(const T * dy, const T * dy_, const T * in, T * out, int const size) { + int const idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) {return;} + + // out[idx] = in[idx] * 0.5 * (1.0 + tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx]))); + T const var1 = tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx])); + T const var2 = SQRT_2_PI * (1 - var1 * var1) * (0.134145 * in[idx] * in[idx] + 1); + + out[idx] = dy[idx] * dy_[idx] * (0.134145 * SQRT_2_PI * in[idx] * in[idx] * (1 - var1 * var1) - SQRT_2_PI * in[idx] * var2 * (0.134145 * in[idx] * in[idx] + 1) * var1 + var2); +} + + +void GeluLauncher(const float * in, float * out, int const size) { + int const THREAD_ITEMS = 1024; + int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS; + + gelu<<>>(in, out, size); +} + +void GeluLauncher(const double * in, double * out, int const size) { + int const THREAD_ITEMS = 1024; + int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS; + + gelu<<>>(in, out, size); +} + +void GeluGradLauncher(const float * dy, const float * in, float * out, int const size) { + int const THREAD_ITEMS = 1024; + int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS; + + gelu_grad<<>>(dy, in, out, size); +} + +void GeluGradLauncher(const double * dy, const double * in, double * out, int const size) { + int const THREAD_ITEMS = 1024; + int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS; + + gelu_grad<<>>(dy, in, out, size); +} + +void GeluGradGradLauncher(const float * dy, const float * dy_, const float * in, float * out, int const size) { + int const THREAD_ITEMS = 1024; + int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS; + + gelu_grad_grad<<>>(dy, dy_, in, out, size); +} + +void 
GeluGradGradLauncher(const double * dy, const double * dy_, const double * in, double * out, int const size) { + int const THREAD_ITEMS = 1024; + int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS; + + gelu_grad_grad<<>>(dy, dy_, in, out, size); +} diff --git a/source/op/gelu.cc b/source/op/gelu.cc new file mode 100644 index 0000000000..2e59de9e34 --- /dev/null +++ b/source/op/gelu.cc @@ -0,0 +1,176 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/shape_inference.h" +#define SQRT_2_PI 0.7978845608028654 + +using namespace tensorflow; +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +REGISTER_OP("Gelu") + .Attr("T: {float, double}") + .Input("x: T") + .Output("output: T") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); + +REGISTER_OP("GeluGrad") + .Attr("T: {float, double}") + .Input("dy: T") + .Input("x: T") + .Output("output: T") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(1)); + return Status::OK(); + }); + +REGISTER_OP("GeluGradGrad") + .Attr("T: {float, double}") + .Input("dy: T") + .Input("dy_: T") + .Input("x: T") + .Output("output: T") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(2)); + return Status::OK(); + }); + +template +struct GeluFunctor { + void operator()(const Device& d, const T * in, T * out, int const size) { + #pragma omp parallel for + for (int ii = 0; ii < size; ii++) { + out[ii] = in[ii] * 0.5 * (1.0 + tanh(SQRT_2_PI * (in[ii] + 0.044715 * in[ii] * in[ii] * in[ii]))); + } + } +}; + +template +struct GeluGradFunctor { + void operator()(const Device& d, const T * dy, const T * in, T * out, int const size) { + #pragma omp parallel for + for (int ii = 0; ii < size; ii++) { 
+ T const var1 = tanh(SQRT_2_PI * (in[ii] + 0.044715 * in[ii] * in[ii] *in[ii])); + out[ii] = dy[ii] * (0.5 * SQRT_2_PI * in[ii] * (1 - var1 * var1) * (0.134145 * in[ii] * in[ii] + 1) + 0.5 * var1 + 0.5); + } + } +}; + +template +struct GeluGradGradFunctor { + void operator()(const Device& d, const T * dy, const T * dy_, const T * in, T * out, int const size) { + #pragma omp parallel for + for (int ii = 0; ii < size; ii++) { + T const var1 = tanh(SQRT_2_PI * (in[ii] + 0.044715 * in[ii] * in[ii] *in[ii])); + T const var2 = SQRT_2_PI * (1 - var1 * var1) * (0.134145 * in[ii] * in[ii] + 1); + + out[ii] = dy[ii] * dy_[ii] * (0.134145 * SQRT_2_PI * in[ii] * in[ii] * (1 - var1 * var1) - SQRT_2_PI * in[ii] * var2 * (0.134145 * in[ii] * in[ii] + 1) * var1 + var2); + } + } +}; + +// OpKernel definition. +// template parameter is the datatype of the tensors. +template +class GeluOp : public OpKernel { + public : + explicit GeluOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& x = context->input(0); + + Tensor * output = NULL; + int context_output_index = 0; + OP_REQUIRES_OK(context, context->allocate_output(context_output_index++, + x.shape(), + &output)); + + GeluFunctor()( + context->eigen_device(), + x.flat().data(), + output->flat().data(), + static_cast(output->NumElements()) + ); + // GeluLauncher(x.flat().data(), output->flat().data(), static_cast(output->NumElements())); + } +}; + +// OpKernel definition. +// template parameter is the datatype of the tensors. 
+template +class GeluGradOp : public OpKernel { + public : + explicit GeluGradOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& dy = context->input(0); + const Tensor& x = context->input(1); + + Tensor * output = NULL; + int context_output_index = 0; + OP_REQUIRES_OK(context, context->allocate_output(context_output_index++, + x.shape(), + &output)); + + GeluGradFunctor()( + context->eigen_device(), + dy.flat().data(), + x.flat().data(), + output->flat().data(), + static_cast(output->NumElements()) + ); + // GeluGradLauncher(dy.flat().data(), x.flat().data(), output->flat().data(), static_cast(output->NumElements())); + } +}; + +// OpKernel definition. +// template parameter is the datatype of the tensors. +template +class GeluGradGradOp : public OpKernel { + public : + explicit GeluGradGradOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& dy = context->input(0); + const Tensor& dy_ = context->input(1); + const Tensor& x = context->input(2); + + Tensor * output = NULL; + int context_output_index = 0; + OP_REQUIRES_OK(context, context->allocate_output(context_output_index++, + x.shape(), + &output)); + + GeluGradGradFunctor()( + context->eigen_device(), + dy.flat().data(), + dy_.flat().data(), + x.flat().data(), + output->flat().data(), + static_cast(output->NumElements()) + ); + // GeluGradGradLauncher(dy.flat().data(), x.flat().data(), output->flat().data(), static_cast(output->NumElements())); + } +}; + +#define REGISTER_CPU(T) \ + /* Declare explicit instantiations in kernel_example.cu.cc. */ \ + REGISTER_KERNEL_BUILDER( \ + Name("Gelu").Device(DEVICE_CPU).TypeConstraint("T"), \ + GeluOp); \ + /* Declare explicit instantiations in kernel_example.cu.cc. 
*/ \ + REGISTER_KERNEL_BUILDER( \ + Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + GeluGradOp); \ + /* Declare explicit instantiations in kernel_example.cu.cc. */ \ + REGISTER_KERNEL_BUILDER( \ + Name("GeluGradGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + GeluGradGradOp); + REGISTER_CPU(float); + REGISTER_CPU(double); \ No newline at end of file diff --git a/source/op/gelu_gpu.cc b/source/op/gelu_gpu.cc new file mode 100644 index 0000000000..d41c438882 --- /dev/null +++ b/source/op/gelu_gpu.cc @@ -0,0 +1,171 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/shape_inference.h" + +using namespace tensorflow; +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +REGISTER_OP("Gelu") + .Attr("T: {float, double}") + .Input("x: T") + .Output("output: T") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); + +REGISTER_OP("GeluGrad") + .Attr("T: {float, double}") + .Input("dy: T") + .Input("x: T") + .Output("output: T") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(1)); + return Status::OK(); + }); + +REGISTER_OP("GeluGradGrad") + .Attr("T: {float, double}") + .Input("dy: T") + .Input("dy_: T") + .Input("x: T") + .Output("output: T") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(2)); + return Status::OK(); + }); + +// maybe instead use cudnn activation forward +void GeluLauncher(const float * in, float * out, int const size); +void GeluLauncher(const double * in, double * out, int const size); + +void GeluGradLauncher(const float * dy, const float * in, float * out, int const size); +void GeluGradLauncher(const double * dy, const double * in, double * out, int const size); + +void GeluGradGradLauncher(const 
float * dy, const float * dy_, const float * in, float * out, int const size); +void GeluGradGradLauncher(const double * dy, const double * dy_, const double * in, double * out, int const size); + +template +struct GeluFunctor { + void operator()(const Device& d, const T * in, T * out, int const size) { + GeluLauncher(in, out, size); + } +}; + +template +struct GeluGradFunctor { + void operator()(const Device& d, const T * dy, const T * in, T * out, int const size) { + GeluGradLauncher(dy, in, out, size); + } +}; + +template +struct GeluGradGradFunctor { + void operator()(const Device& d, const T * dy, const T * dy_, const T * in, T * out, int const size) { + GeluGradGradLauncher(dy, dy_, in, out, size); + } +}; + +// OpKernel definition. +// template parameter is the datatype of the tensors. +template +class GeluOp : public OpKernel { + public : + explicit GeluOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& x = context->input(0); + Tensor * output = NULL; + int context_output_index = 0; + OP_REQUIRES_OK(context, context->allocate_output(context_output_index++, + x.shape(), + &output)); + + GeluFunctor()( + context->eigen_device(), + x.flat().data(), + output->flat().data(), + static_cast(output->NumElements()) + ); + // GeluLauncher(x.flat().data(), output->flat().data(), static_cast(output->NumElements())); + } +}; + +// OpKernel definition. +// template parameter is the datatype of the tensors. 
+template +class GeluGradOp : public OpKernel { + public : + explicit GeluGradOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& dy = context->input(0); + const Tensor& x = context->input(1); + + Tensor * output = NULL; + int context_output_index = 0; + OP_REQUIRES_OK(context, context->allocate_output(context_output_index++, + x.shape(), + &output)); + + GeluGradFunctor()( + context->eigen_device(), + dy.flat().data(), + x.flat().data(), + output->flat().data(), + static_cast(output->NumElements()) + ); + // GeluGradLauncher(dy.flat().data(), x.flat().data(), output->flat().data(), static_cast(output->NumElements())); + } +}; + +// OpKernel definition. +// template parameter is the datatype of the tensors. +template +class GeluGradGradOp : public OpKernel { + public : + explicit GeluGradGradOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& dy = context->input(0); + const Tensor& dy_ = context->input(1); + const Tensor& x = context->input(2); + + Tensor * output = NULL; + int context_output_index = 0; + OP_REQUIRES_OK(context, context->allocate_output(context_output_index++, + x.shape(), + &output)); + + GeluGradGradFunctor()( + context->eigen_device(), + dy.flat().data(), + dy_.flat().data(), + x.flat().data(), + output->flat().data(), + static_cast(output->NumElements()) + ); + // GeluGradGradLauncher(dy.flat().data(), x.flat().data(), output->flat().data(), static_cast(output->NumElements())); + } +}; + +#define REGISTER_GPU(T) \ + /* Declare explicit instantiations in kernel_example.cu.cc. */ \ + REGISTER_KERNEL_BUILDER( \ + Name("Gelu").Device(DEVICE_GPU).TypeConstraint("T"), \ + GeluOp); \ + /* Declare explicit instantiations in kernel_example.cu.cc. 
*/ \ + REGISTER_KERNEL_BUILDER( \ + Name("GeluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + GeluGradOp); \ + /* Declare explicit instantiations in kernel_example.cu.cc. */ \ + REGISTER_KERNEL_BUILDER( \ + Name("GeluGradGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + GeluGradGradOp); + REGISTER_GPU(float); + REGISTER_GPU(double); diff --git a/source/train/Trainer.py b/source/train/Trainer.py index 50db1adfbc..b6428c987e 100644 --- a/source/train/Trainer.py +++ b/source/train/Trainer.py @@ -28,6 +28,7 @@ import deepmd._prod_virial_se_r_grad import deepmd._soft_min_force_grad import deepmd._soft_min_virial_grad +import deepmd._gelu from deepmd.common import j_must_have, ClassArg diff --git a/source/train/common.py b/source/train/common.py index c250e4cdfc..887669a278 100644 --- a/source/train/common.py +++ b/source/train/common.py @@ -2,19 +2,22 @@ import numpy as np import math from deepmd.env import tf +from deepmd.env import op_module from deepmd.RunOptions import global_tf_float_precision -def gelu(x): - """Gaussian Error Linear Unit. - This is a smoother version of the RELU. - Original paper: https://arxiv.org/abs/1606.08415 - Args: - x: float Tensor to perform activation. - Returns: - `x` with the GELU activation applied. - """ - cdf = 0.5 * (1.0 + tf.tanh((math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))) - return x * cdf +# def gelu(x): +# """Gaussian Error Linear Unit. +# This is a smoother version of the RELU. +# Original paper: https://arxiv.org/abs/1606.08415 +# Args: +# x: float Tensor to perform activation. +# Returns: +# `x` with the GELU activation applied. 
+# """ +# cdf = 0.5 * (1.0 + tf.tanh((math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))) +# return x * cdf +def gelu(x) : + return op_module.gelu(x) data_requirement = {} activation_fn_dict = { @@ -40,7 +43,6 @@ def add_data_requirement(key, 'repeat': repeat, } - def select_idx_map(atom_type, type_sel): sort_type_sel = np.sort(type_sel) From 7b0c12ee1038350561e5cbd046d4c4780d2c19ff Mon Sep 17 00:00:00 2001 From: Lu Date: Mon, 13 Apr 2020 15:32:28 +0800 Subject: [PATCH 2/2] remove shared library LIB_DEEPMD_OP_CUDA, instead use static library --- source/CMakeLists.txt | 5 ----- source/lmp/env.sh.in | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index 6b18cb95ac..0066e032a7 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -184,11 +184,6 @@ include_directories(${TensorFlow_INCLUDE_DIRS}) if (BUILD_CPP_IF) set (LIB_DEEPMD "deepmd") set (LIB_DEEPMD_OP "deepmd_op") - if (USE_CUDA_TOOLKIT) - set (LIB_DEEPMD_OP_CUDA "deepmd_op_cuda") - else() - set (LIB_DEEPMD_OP_CUDA "deepmd_op") - endif() if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 4.9) set (LIB_DEEPMD_NATIVE "deepmd_native_md") set (LIB_DEEPMD_IPI "deepmd_ipi") diff --git a/source/lmp/env.sh.in b/source/lmp/env.sh.in index 00bc3b18b6..7f58018b79 100644 --- a/source/lmp/env.sh.in +++ b/source/lmp/env.sh.in @@ -8,4 +8,4 @@ TF_RPATH=`echo $TENSORFLOW_LIBRARY_PATH | sed "s/;/ -Wl,-rpath=/g"` NNP_INC=" -std=c++11 @PREC_DEF@ @TTM_DEF@ @OLD_LMP_PPPM_DEF@ -I$TF_INCLUDE_DIRS -I$DEEPMD_ROOT/include/deepmd " NNP_PATH=" -L$TF_LIBRARY_PATH -L$DEEPMD_ROOT/lib" -NNP_LIB=" -Wl,--no-as-needed -l@LIB_DEEPMD_OP@ -l@LIB_DEEPMD_OP_CUDA@ -l@LIB_DEEPMD@ -ltensorflow_cc -ltensorflow_framework -Wl,-rpath=$TF_RPATH -Wl,-rpath=$DEEPMD_ROOT/lib" +NNP_LIB=" -Wl,--no-as-needed -l@LIB_DEEPMD_OP@ -l@LIB_DEEPMD@ -ltensorflow_cc -ltensorflow_framework -Wl,-rpath=$TF_RPATH -Wl,-rpath=$DEEPMD_ROOT/lib"