Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,6 @@ include_directories(${TensorFlow_INCLUDE_DIRS})
if (BUILD_CPP_IF)
set (LIB_DEEPMD "deepmd")
set (LIB_DEEPMD_OP "deepmd_op")
if (USE_CUDA_TOOLKIT)
set (LIB_DEEPMD_OP_CUDA "deepmd_op_cuda")
else()
set (LIB_DEEPMD_OP_CUDA "deepmd_op")
endif()
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 4.9)
set (LIB_DEEPMD_NATIVE "deepmd_native_md")
set (LIB_DEEPMD_IPI "deepmd_ipi")
Expand Down
2 changes: 1 addition & 1 deletion source/lmp/env.sh.in
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ TF_RPATH=`echo $TENSORFLOW_LIBRARY_PATH | sed "s/;/ -Wl,-rpath=/g"`

NNP_INC=" -std=c++11 @PREC_DEF@ @TTM_DEF@ @OLD_LMP_PPPM_DEF@ -I$TF_INCLUDE_DIRS -I$DEEPMD_ROOT/include/deepmd "
NNP_PATH=" -L$TF_LIBRARY_PATH -L$DEEPMD_ROOT/lib"
NNP_LIB=" -Wl,--no-as-needed -l@LIB_DEEPMD_OP@ -l@LIB_DEEPMD_OP_CUDA@ -l@LIB_DEEPMD@ -ltensorflow_cc -ltensorflow_framework -Wl,-rpath=$TF_RPATH -Wl,-rpath=$DEEPMD_ROOT/lib"
NNP_LIB=" -Wl,--no-as-needed -l@LIB_DEEPMD_OP@ -l@LIB_DEEPMD@ -ltensorflow_cc -ltensorflow_framework -Wl,-rpath=$TF_RPATH -Wl,-rpath=$DEEPMD_ROOT/lib"
21 changes: 17 additions & 4 deletions source/op/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
set(OP_LIB ${PROJECT_SOURCE_DIR}/lib/src/SimulationRegion.cpp ${PROJECT_SOURCE_DIR}/lib/src/NeighborList.cpp)

set (OP_CXX_FLAG -D_GLIBCXX_USE_CXX11_ABI=${OP_CXX_ABI} )
file(GLOB OP_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a.cc descrpt_se_r.cc tab_inter.cc prod_force_se_a.cc prod_virial_se_a.cc prod_force_se_r.cc prod_virial_se_r.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc)
file(GLOB OP_CUDA_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a_gpu.cc descrpt_se_r_gpu.cc tab_inter.cc prod_force_se_a_gpu.cc prod_virial_se_a_gpu.cc prod_force_se_r_gpu.cc prod_virial_se_r_gpu.cc soft_min.cc soft_min_force.cc soft_min_virial.cc )
file(GLOB OP_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a.cc descrpt_se_r.cc tab_inter.cc prod_force_se_a.cc prod_virial_se_a.cc prod_force_se_r.cc prod_virial_se_r.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc gelu.cc)
file(GLOB OP_PY_CUDA_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a.cc descrpt_se_r.cc tab_inter.cc prod_force_se_a.cc prod_virial_se_a.cc prod_force_se_r.cc prod_virial_se_r.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc gelu_gpu.cc)
file(GLOB OP_CUDA_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a_gpu.cc descrpt_se_r_gpu.cc tab_inter.cc prod_force_se_a_gpu.cc prod_virial_se_a_gpu.cc prod_force_se_r_gpu.cc prod_virial_se_r_gpu.cc soft_min.cc soft_min_force.cc soft_min_virial.cc gelu_gpu.cc)
file(GLOB OP_GRADS_SRC prod_force_grad.cc prod_force_se_a_grad.cc prod_force_se_r_grad.cc prod_virial_grad.cc prod_virial_se_a_grad.cc prod_virial_se_r_grad.cc soft_min_force_grad.cc soft_min_virial_grad.cc )
file(GLOB OP_PY *.py)

Expand All @@ -23,8 +24,20 @@ if (BUILD_CPP_IF)
endif (BUILD_CPP_IF)

if (BUILD_PY_IF)
add_library(op_abi SHARED ${OP_SRC} ${OP_LIB})
add_library(op_grads SHARED ${OP_GRADS_SRC})
if (USE_CUDA_TOOLKIT)
add_library(op_abi SHARED ${OP_PY_CUDA_SRC} ${OP_LIB})
add_library(op_grads SHARED ${OP_GRADS_SRC})
add_subdirectory(cuda)
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})
set (EXTRA_LIBS ${EXTRA_LIBS} deepmd_op_cuda)
target_link_libraries (op_abi ${EXTRA_LIBS})
target_link_libraries (op_grads ${EXTRA_LIBS})
message(STATUS ${TensorFlowFramework_LIBRARY})
else (USE_CUDA_TOOLKIT)
add_library(op_abi SHARED ${OP_SRC} ${OP_LIB})
add_library(op_grads SHARED ${OP_GRADS_SRC})
endif(USE_CUDA_TOOLKIT)
target_link_libraries(
op_abi ${TensorFlowFramework_LIBRARY}
)
Expand Down
15 changes: 15 additions & 0 deletions source/op/_gelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/usr/bin/env python3
"""Gradient registrations for the custom ``Gelu`` TensorFlow ops.

Wires the C++/CUDA ``gelu_grad`` and ``gelu_grad_grad`` kernels into
TensorFlow's autodiff machinery so that first- and second-order
derivatives of ``Gelu`` are available during training.
"""

from tensorflow.python.framework import ops
from deepmd.env import op_module


@ops.RegisterGradient("Gelu")
def gelu_cc(op, dy):
    """First derivative of Gelu, delegated to the custom gelu_grad kernel."""
    return op_module.gelu_grad(dy, op.inputs[0])


@ops.RegisterGradient("GeluGrad")
def gelu_grad_cc(op, dy):
    """Gradient of GeluGrad: no gradient flows back through the first
    input (``dy`` of the forward op), hence the leading ``None``; the
    custom gelu_grad_grad kernel handles the ``x`` input."""
    return [None, op_module.gelu_grad_grad(dy, op.inputs[0], op.inputs[1])]
11 changes: 8 additions & 3 deletions source/op/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,14 @@ else ()
endif()

set (SOURCE_FILES
descrpt_se_a.cu descrpt_se_r.cu prod_force_se_a.cu prod_force_se_r.cu prod_virial_se_a.cu prod_virial_se_r.cu
descrpt_se_a.cu descrpt_se_r.cu prod_force_se_a.cu prod_force_se_r.cu prod_virial_se_a.cu prod_virial_se_r.cu gelu.cu
)

cuda_add_library(deepmd_op_cuda SHARED ${SOURCE_FILES})
cuda_add_library(deepmd_op_cuda STATIC ${SOURCE_FILES})

install(TARGETS deepmd_op_cuda DESTINATION lib/)
if (BUILD_CPP_IF)
install(TARGETS deepmd_op_cuda DESTINATION lib/)
endif (BUILD_CPP_IF)
if (BUILD_PY_IF)
install(TARGETS deepmd_op_cuda DESTINATION deepmd/)
endif (BUILD_PY_IF)
77 changes: 77 additions & 0 deletions source/op/cuda/gelu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include <cuda_runtime.h>
#include <stdio.h>

#define SQRT_2_PI 0.7978845608028654

// Elementwise GELU forward pass (tanh approximation), one thread per element:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// SQRT_2_PI is sqrt(2/pi); the launcher sizes the grid to cover `size`.
template <typename T>
__global__ void gelu(const T * in, T * out, int const size) {
    // Global linear thread index; threads past the end exit early.
    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= size) {return;}

    out[idx] = in[idx] * 0.5 * (1.0 + tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx])));
}

// First derivative of the tanh-approximated GELU, scaled by the
// upstream gradient dy:  out = dy * d/dx gelu(x).
template <typename T>
__global__ void gelu_grad(const T * dy, const T * in, T * out, int const size) {
    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= size) {return;}

    // var1 = tanh(u) with u = sqrt(2/pi) * (x + 0.044715 x^3);
    // 0.134145 = 3 * 0.044715 arises from differentiating u.
    T const var1 = tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx]));
    out[idx] = dy[idx] * (0.5 * SQRT_2_PI * in[idx] * (1 - var1 * var1) * (0.134145 * in[idx] * in[idx] + 1) + 0.5 * var1 + 0.5);
}

// Second derivative of the tanh-approximated GELU: folds both upstream
// gradients (dy from this op's consumer, dy_ carried through GeluGrad)
// into the second-derivative expression of gelu(x).
template <typename T>
__global__ void gelu_grad_grad(const T * dy, const T * dy_, const T * in, T * out, int const size) {
    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= size) {return;}

    // var1 = tanh(u), u = sqrt(2/pi) * (x + 0.044715 x^3)
    T const var1 = tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx]));
    // var2 = (1 - tanh^2(u)) * u'  where u' = sqrt(2/pi) * (0.134145 x^2 + 1)
    T const var2 = SQRT_2_PI * (1 - var1 * var1) * (0.134145 * in[idx] * in[idx] + 1);

    out[idx] = dy[idx] * dy_[idx] * (0.134145 * SQRT_2_PI * in[idx] * in[idx] * (1 - var1 * var1) - SQRT_2_PI * in[idx] * var2 * (0.134145 * in[idx] * in[idx] + 1) * var1 + var2);
}


// Host-side launchers for the forward gelu kernel: one thread per
// element, with the grid rounded up so a partial block covers the tail.
void GeluLauncher(const float * in, float * out, int const size) {
    int const threads_per_block = 1024;
    int const num_blocks = (size + threads_per_block - 1) / threads_per_block;  // ceil(size / threads)

    gelu<<<num_blocks, threads_per_block>>>(in, out, size);
}

void GeluLauncher(const double * in, double * out, int const size) {
    int const threads_per_block = 1024;
    int const num_blocks = (size + threads_per_block - 1) / threads_per_block;

    gelu<<<num_blocks, threads_per_block>>>(in, out, size);
}

// Host-side launchers for the first-derivative kernel; grid sizing is
// the same ceil-division scheme as the forward launcher.
void GeluGradLauncher(const float * dy, const float * in, float * out, int const size) {
    int const threads_per_block = 1024;
    int const num_blocks = (size + threads_per_block - 1) / threads_per_block;

    gelu_grad<<<num_blocks, threads_per_block>>>(dy, in, out, size);
}

void GeluGradLauncher(const double * dy, const double * in, double * out, int const size) {
    int const threads_per_block = 1024;
    int const num_blocks = (size + threads_per_block - 1) / threads_per_block;

    gelu_grad<<<num_blocks, threads_per_block>>>(dy, in, out, size);
}

// Host-side launchers for the second-derivative kernel.
void GeluGradGradLauncher(const float * dy, const float * dy_, const float * in, float * out, int const size) {
    int const threads_per_block = 1024;
    int const num_blocks = (size + threads_per_block - 1) / threads_per_block;

    gelu_grad_grad<<<num_blocks, threads_per_block>>>(dy, dy_, in, out, size);
}

void GeluGradGradLauncher(const double * dy, const double * dy_, const double * in, double * out, int const size) {
    int const threads_per_block = 1024;
    int const num_blocks = (size + threads_per_block - 1) / threads_per_block;

    gelu_grad_grad<<<num_blocks, threads_per_block>>>(dy, dy_, in, out, size);
}
176 changes: 176 additions & 0 deletions source/op/gelu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
#define SQRT_2_PI 0.7978845608028654

using namespace tensorflow;
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

// Forward op: output = gelu(x); output shape mirrors input x.
REGISTER_OP("Gelu")
    .Attr("T: {float, double}")
    .Input("x: T")
    .Output("output: T")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
        c->set_output(0, c->input(0));
        return Status::OK();
    });

// First-derivative op: inputs are the upstream gradient dy and x;
// output shape mirrors x (input index 1).
REGISTER_OP("GeluGrad")
    .Attr("T: {float, double}")
    .Input("dy: T")
    .Input("x: T")
    .Output("output: T")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
        c->set_output(0, c->input(1));
        return Status::OK();
    });

// Second-derivative op: two upstream gradients (dy, dy_) plus x;
// output shape mirrors x (input index 2).
REGISTER_OP("GeluGradGrad")
    .Attr("T: {float, double}")
    .Input("dy: T")
    .Input("dy_: T")
    .Input("x: T")
    .Output("output: T")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
        c->set_output(0, c->input(2));
        return Status::OK();
    });

// CPU implementation of the forward GELU (tanh approximation):
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// The element loop is parallelized with OpenMP; `d` is unused on CPU.
template <typename Device, typename T>
struct GeluFunctor {
    void operator()(const Device& d, const T * in, T * out, int const size) {
        #pragma omp parallel for
        for (int ii = 0; ii < size; ii++) {
            out[ii] = in[ii] * 0.5 * (1.0 + tanh(SQRT_2_PI * (in[ii] + 0.044715 * in[ii] * in[ii] * in[ii])));
        }
    }
};

// CPU implementation of the GELU first derivative scaled by the
// upstream gradient dy; mirrors the CUDA gelu_grad kernel.
template <typename Device, typename T>
struct GeluGradFunctor {
    void operator()(const Device& d, const T * dy, const T * in, T * out, int const size) {
        #pragma omp parallel for
        for (int ii = 0; ii < size; ii++) {
            // var1 = tanh(u), u = sqrt(2/pi) * (x + 0.044715 x^3);
            // 0.134145 = 3 * 0.044715 comes from u'.
            T const var1 = tanh(SQRT_2_PI * (in[ii] + 0.044715 * in[ii] * in[ii] *in[ii]));
            out[ii] = dy[ii] * (0.5 * SQRT_2_PI * in[ii] * (1 - var1 * var1) * (0.134145 * in[ii] * in[ii] + 1) + 0.5 * var1 + 0.5);
        }
    }
};

// CPU implementation of the GELU second derivative combining both
// upstream gradients dy and dy_; mirrors the CUDA gelu_grad_grad kernel.
template <typename Device, typename T>
struct GeluGradGradFunctor {
    void operator()(const Device& d, const T * dy, const T * dy_, const T * in, T * out, int const size) {
        #pragma omp parallel for
        for (int ii = 0; ii < size; ii++) {
            // var1 = tanh(u), u = sqrt(2/pi) * (x + 0.044715 x^3)
            T const var1 = tanh(SQRT_2_PI * (in[ii] + 0.044715 * in[ii] * in[ii] *in[ii]));
            // var2 = (1 - tanh^2(u)) * u'
            T const var2 = SQRT_2_PI * (1 - var1 * var1) * (0.134145 * in[ii] * in[ii] + 1);

            out[ii] = dy[ii] * dy_[ii] * (0.134145 * SQRT_2_PI * in[ii] * in[ii] * (1 - var1 * var1) - SQRT_2_PI * in[ii] * var2 * (0.134145 * in[ii] * in[ii] + 1) * var1 + var2);
        }
    }
};

// OpKernel for the forward Gelu op.
// Allocates an output tensor shaped like the input and delegates the
// elementwise work to GeluFunctor for the selected Device/T.
template <typename Device, typename T>
class GeluOp : public OpKernel {
  public :
    explicit GeluOp(OpKernelConstruction* context) : OpKernel(context) {}

    void Compute(OpKernelContext* context) override {
        const Tensor& x = context->input(0);

        // Single output, same shape as x.
        Tensor* output = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &output));

        const int size = static_cast<int>(output->NumElements());
        GeluFunctor<Device, T>()(context->eigen_device<Device>(),
                                 x.flat<T>().data(),
                                 output->flat<T>().data(),
                                 size);
    }
};

// OpKernel for GeluGrad (first derivative).
// Inputs: dy (upstream gradient) and x; output is shaped like x and
// filled by GeluGradFunctor for the selected Device/T.
template <typename Device, typename T>
class GeluGradOp : public OpKernel {
  public :
    explicit GeluGradOp(OpKernelConstruction* context) : OpKernel(context) {}

    void Compute(OpKernelContext* context) override {
        const Tensor& dy = context->input(0);
        const Tensor& x = context->input(1);

        // Single output, same shape as x.
        Tensor* output = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &output));

        const int size = static_cast<int>(output->NumElements());
        GeluGradFunctor<Device, T>()(context->eigen_device<Device>(),
                                     dy.flat<T>().data(),
                                     x.flat<T>().data(),
                                     output->flat<T>().data(),
                                     size);
    }
};

// OpKernel for GeluGradGrad (second derivative).
// Inputs: two upstream gradients (dy, dy_) and x; output is shaped
// like x and filled by GeluGradGradFunctor for the selected Device/T.
template <typename Device, typename T>
class GeluGradGradOp : public OpKernel {
  public :
    explicit GeluGradGradOp(OpKernelConstruction* context) : OpKernel(context) {}

    void Compute(OpKernelContext* context) override {
        const Tensor& dy = context->input(0);
        const Tensor& dy_ = context->input(1);
        const Tensor& x = context->input(2);

        // Single output, same shape as x.
        Tensor* output = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &output));

        const int size = static_cast<int>(output->NumElements());
        GeluGradGradFunctor<Device, T>()(context->eigen_device<Device>(),
                                         dy.flat<T>().data(),
                                         dy_.flat<T>().data(),
                                         x.flat<T>().data(),
                                         output->flat<T>().data(),
                                         size);
    }
};

// Register the CPU kernels of all three ops for element type T.
#define REGISTER_CPU(T)                                                  \
    /* Forward gelu. */                                                  \
    REGISTER_KERNEL_BUILDER(                                             \
        Name("Gelu").Device(DEVICE_CPU).TypeConstraint<T>("T"),          \
        GeluOp<CPUDevice, T>);                                           \
    /* First derivative. */                                              \
    REGISTER_KERNEL_BUILDER(                                             \
        Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"),      \
        GeluGradOp<CPUDevice, T>);                                       \
    /* Second derivative. */                                             \
    REGISTER_KERNEL_BUILDER(                                             \
        Name("GeluGradGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"),  \
        GeluGradGradOp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(double);
Loading