39 changes: 29 additions & 10 deletions lib/kernels/include/kernels/linear_kernels.h
@@ -5,27 +5,46 @@

namespace FlexFlow {

class LinearPerDeviceState : public PerDeviceOpState {
public:
LinearPerDeviceState(FFHandler handle, int batch_size);
struct LinearPerDeviceState {
PerDeviceFFHandle handle;
ffTensorDescriptor_t outputTensor;
ffActivationDescriptor_t actiDesc;

public:
float const *one_ptr;
ActiMode activation;
float const *one_ptr; // device-side all-ones vector (batch_size floats) used to broadcast the bias
ActiMode activation; // kept as ActiMode; the cuDNN mode is baked into actiDesc at init
optional<Regularizer> regularizer;
bool use_bias;
DataType input_type, weight_type, output_type;
};

FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(LinearPerDeviceState,
handle,
outputTensor,
actiDesc,
one_ptr,
activation,
regularizer,
use_bias,
input_type,
weight_type,
output_type);
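
// Note: FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION makes the listed fields
// visitable, which is presumably what lets the runtime derive value-type
// operations (e.g. equality/hashing) for the per-device state.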

namespace Kernels {
namespace Linear {
void init_kernel(LinearPerDeviceState *m, int batch_size, int channel);

LinearPerDeviceState
init_kernel(PerDeviceFFHandle handle, Allocator allocator, float *one_ptr,
ActiMode activation,
optional<Regularizer> regularizer,
bool use_bias,
DataType input_type,
DataType weight_type,
DataType output_type,
int batch_size,
int channel);
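
// Sketch of the intended call site (illustrative names; `attrs` and the
// data-type accessors are assumptions, not part of this header):
//   LinearPerDeviceState state =
//       init_kernel(handle, allocator, one_ptr,
//                   attrs.activation, attrs.regularizer, attrs.use_bias,
//                   input.data_type, weight.data_type, output.data_type,
//                   batch_size, out_channels);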

bool use_activation(ActiMode mode);

void forward_kernel(ffStream_t stream,
LinearPerDeviceState const *m,
LinearPerDeviceState const &m,
void const *input_ptr,
void *output_ptr,
void const *filter_ptr,
@@ -34,7 +53,7 @@ void forward_kernel(ffStream_t stream,
int out_dim,
int batch_size);
void backward_kernel(ffStream_t stream,
LinearPerDeviceState const *m,
LinearPerDeviceState const &m,
void const *input_ptr,
void *input_grad_ptr,
void const *output_ptr,
182 changes: 92 additions & 90 deletions lib/kernels/src/cuda/linear_kernels.cu
@@ -13,76 +13,78 @@
* limitations under the License.
*/

#include "kernels/allocation.h"
#include "kernels/cuda_helper.h"
#include "kernels/linear_kernels.h"

namespace FlexFlow {

LinearPerDeviceState::LinearPerDeviceState(FFHandler handler, int batch_size)
: PerDeviceOpState(handler) {
// Allocate an all-ones vector
float *dram_one_ptr = (float *)malloc(sizeof(float) * batch_size);
for (int i = 0; i < batch_size; i++) {
dram_one_ptr[i] = 1.0f;
}
float *fb_one_ptr;
checkCUDA(cudaMalloc(&fb_one_ptr, sizeof(float) * batch_size));
checkCUDA(cudaMemcpy(fb_one_ptr,
dram_one_ptr,
sizeof(float) * batch_size,
cudaMemcpyHostToDevice));
one_ptr = (float const *)fb_one_ptr;
// Allocate descriptors
checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc));
checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor));
}

namespace Kernels {
namespace Linear {

bool use_activation(ActiMode mode) {
switch (mode) {
case AC_MODE_RELU:
case AC_MODE_SIGMOID:
case AC_MODE_TANH:
return true;
case AC_MODE_NONE:
return false;
default:
assert(0);
// one_ptr is a device-side all-ones vector of length batch_size; forward and
// backward use it to broadcast/reduce the bias across the batch via GEMM.
LinearPerDeviceState
init_kernel(PerDeviceFFHandle handle, Allocator allocator, float *one_ptr,
ActiMode activation,
optional<Regularizer> regularizer,
bool use_bias,
DataType input_type,
DataType weight_type,
DataType output_type,
int batch_size,
int channel) {
ffTensorDescriptor_t outputTensor;
ffActivationDescriptor_t actiDesc;
checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor));
checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc));
checkCUDNN(cudnnSetTensor4dDescriptor(outputTensor,
CUDNN_TENSOR_NCHW,
ff_to_cudnn_datatype(output_type),
batch_size,
channel,
1,
1));
cudnnActivationMode_t mode;
switch (activation) {
case AC_MODE_RELU:
mode = CUDNN_ACTIVATION_RELU;
break;
case AC_MODE_SIGMOID:
mode = CUDNN_ACTIVATION_SIGMOID;
break;
case AC_MODE_TANH:
mode = CUDNN_ACTIVATION_TANH;
break;
case AC_MODE_GELU:
case AC_MODE_NONE:
// cuDNN has no GELU activation mode: GELU goes through the custom
// gelu_forward_kernel in forward_kernel, and AC_MODE_NONE applies no
// activation, so the descriptor is set to identity and never used.
mode = CUDNN_ACTIVATION_IDENTITY;
break;
default:
// Unsupported activation mode
assert(false);
}
return false;
}
checkCUDNN(
cudnnSetActivationDescriptor(actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0));

void init_kernel(LinearPerDeviceState *m, int batch_size, int channel) {
if (use_activation(m->activation)) {
cudnnActivationMode_t mode;
switch (m->activation) {
case AC_MODE_RELU:
mode = CUDNN_ACTIVATION_RELU;
break;
case AC_MODE_SIGMOID:
mode = CUDNN_ACTIVATION_SIGMOID;
break;
default:
// Unsupported activation mode
assert(false);
}
checkCUDNN(cudnnSetActivationDescriptor(
m->actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0));
checkCUDNN(cudnnSetTensor4dDescriptor(m->outputTensor,
CUDNN_TENSOR_NCHW,
ff_to_cudnn_datatype(m->output_type),
batch_size,
channel,
1,
1));
}
// Allocate the device-side all-ones vector: batch_size floats. If Allocator
// exposes an allocate(size_t) method (an assumption), it could replace the
// raw cudaMalloc below; either way the buffer must be filled with 1.0f, as
// the old constructor did.
checkCUDA(cudaMalloc(&one_ptr, sizeof(float) * batch_size));
std::vector<float> dram_one(batch_size, 1.0f);
checkCUDA(cudaMemcpy(one_ptr,
                     dram_one.data(),
                     sizeof(float) * batch_size,
                     cudaMemcpyHostToDevice));
LinearPerDeviceState per_device_state = {handle,
outputTensor,
actiDesc,
one_ptr,
activation,
regularizer,
use_bias,
input_type,
weight_type,
output_type};
return per_device_state;
}
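
// Design note: init_kernel now builds and returns the state by value as a
// plain visitable struct instead of heap-allocating a PerDeviceOpState
// subclass; callers own the state and pass it to the kernels by const
// reference.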

void forward_kernel(cudaStream_t stream,
LinearPerDeviceState const *m,
LinearPerDeviceState const &m,
void const *input_ptr,
void *output_ptr,
void const *weight_ptr,
@@ -91,19 +93,19 @@ void forward_kernel(cudaStream_t stream,
int out_dim,
int batch_size) {

checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
checkCUDA(cublasSetStream(m.handle.blas, stream));
checkCUDNN(cudnnSetStream(m.handle.dnn, stream));
float alpha = 1.0f, beta = 0.0f;
cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type);
cudaDataType_t weight_type = ff_to_cuda_datatype(m->weight_type);
cudaDataType_t output_type = ff_to_cuda_datatype(m->output_type);
cudaDataType_t input_type = ff_to_cuda_datatype(m.input_type);
cudaDataType_t weight_type = ff_to_cuda_datatype(m.weight_type);
cudaDataType_t output_type = ff_to_cuda_datatype(m.output_type);
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#else
cudaDataType_t compute_type = CUDA_R_32F;
#endif
checkCUDA(cublasGemmEx(m->handle.blas,
checkCUDA(cublasGemmEx(m.handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_N,
out_dim,
@@ -124,7 +126,7 @@ void forward_kernel(cudaStream_t stream,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
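// The GEMM above computes output = W^T * input, with W stored column-major
// as in_dim x out_dim, yielding an out_dim x batch_size output.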
// use_bias == true: broadcast the bias over the batch as a rank-1 update,
// output += bias * ones^T, using the all-ones vector m.one_ptr
if (bias_ptr != NULL) {
checkCUDA(cublasGemmEx(m->handle.blas,
checkCUDA(cublasGemmEx(m.handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_N,
out_dim,
@@ -134,7 +136,7 @@ void forward_kernel(cudaStream_t stream,
bias_ptr,
weight_type,
1,
m->one_ptr,
m.one_ptr,
CUDA_R_32F,
1,
&alpha,
@@ -144,30 +146,30 @@ void forward_kernel(cudaStream_t stream,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
if (use_activation(m->activation)) {
checkCUDNN(cudnnActivationForward(m->handle.dnn,
m->actiDesc,
if (use_activation(m.activation)) {
checkCUDNN(cudnnActivationForward(m.handle.dnn,
m.actiDesc,
&alpha,
m->outputTensor,
m.outputTensor,
output_ptr,
&beta,
m->outputTensor,
m.outputTensor,
output_ptr));
} else if (m->activation == AC_MODE_GELU) {
} else if (m.activation == AC_MODE_GELU) {
size_t elements = (size_t)out_dim * (size_t)batch_size;
constexpr float B = 0.7978845608028654f; // sqrt(2.0/M_PI)
constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0/M_PI)
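// tanh approximation: gelu(x) ~= 0.5 * x * (1 + tanh(B * x + C * x^3))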
gelu_forward_kernel<<<GET_BLOCKS(elements), CUDA_NUM_THREADS>>>(
elements, B, C, (float *)output_ptr);
} else if (m->activation == AC_MODE_NONE) {
} else if (m.activation == AC_MODE_NONE) {
// Do nothing
} else {
assert(false && "Unsupported activation for Linear");
}
}
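
// Minimal usage sketch (assumptions: device buffers are already allocated,
// `state` is a valid LinearPerDeviceState, and get_legion_stream is the
// stream helper used by this repo's kernels):
//   cudaStream_t stream;
//   checkCUDA(get_legion_stream(&stream));
//   forward_kernel(stream, state, input_ptr, output_ptr, weight_ptr,
//                  bias_ptr, in_dim, out_dim, batch_size);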

void backward_kernel(cudaStream_t stream,
LinearPerDeviceState const *m,
LinearPerDeviceState const &m,
void const *input_ptr,
void *input_grad_ptr,
void const *output_ptr,
@@ -179,33 +181,33 @@ void backward_kernel(cudaStream_t stream,
int out_dim,
int batch_size) {

checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
checkCUDA(cublasSetStream(m.handle.blas, stream));
checkCUDNN(cudnnSetStream(m.handle.dnn, stream));

float alpha = 1.0f;
cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type);
cudaDataType_t weight_type = ff_to_cuda_datatype(m->weight_type);
cudaDataType_t output_type = ff_to_cuda_datatype(m->output_type);
cudaDataType_t input_type = ff_to_cuda_datatype(m.input_type);
cudaDataType_t weight_type = ff_to_cuda_datatype(m.weight_type);
cudaDataType_t output_type = ff_to_cuda_datatype(m.output_type);
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#else
cudaDataType_t compute_type = CUDA_R_32F;
#endif
int output_size = out_dim * batch_size;
if (m->activation == AC_MODE_RELU) {
if (m.activation == AC_MODE_RELU) {
relu_backward_kernel(
m->output_type, output_grad_ptr, output_ptr, output_size, stream);
} else if (m->activation == AC_MODE_SIGMOID) {
m.output_type, output_grad_ptr, output_ptr, output_size, stream);
} else if (m.activation == AC_MODE_SIGMOID) {
sigmoid_backward_kernel(
m->output_type, output_grad_ptr, output_ptr, output_size, stream);
m.output_type, output_grad_ptr, output_ptr, output_size, stream);
} else {
// TODO: only support relu and sigmoid for now
assert(m->activation == AC_MODE_NONE);
assert(m.activation == AC_MODE_NONE);
}
// Compute weight gradient: kernel_grad += input * output_grad^T (in_dim x out_dim)
// NOTE: we use alpha=1 for kernel_grad to accumulate gradients
checkCUDA(cublasGemmEx(m->handle.blas,
checkCUDA(cublasGemmEx(m.handle.blas,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_dim,
@@ -224,18 +226,18 @@ void backward_kernel(cudaStream_t stream,
in_dim,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
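// TODO: the state now stores optional<Regularizer> rather than the old
// kernel_reg_type / kernel_reg_lambda fields read below; this block still
// needs to be ported to the new field.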
if (m->kernel_reg_type == REG_MODE_NONE) {
if (m.kernel_reg_type == REG_MODE_NONE) {
// do nothing
} else if (m->kernel_reg_type == REG_MODE_L2) {
checkCUDA(cublasSgeam(m->handle.blas,
} else if (m.kernel_reg_type == REG_MODE_L2) {
checkCUDA(cublasSgeam(m.handle.blas,
CUBLAS_OP_N,
CUBLAS_OP_N,
in_dim,
out_dim,
&alpha,
(float *)kernel_grad_ptr,
in_dim,
&(m->kernel_reg_lambda),
&(m.kernel_reg_lambda),
(float *)kernel_ptr,
in_dim,
(float *)kernel_grad_ptr,
@@ -248,14 +250,14 @@ void backward_kernel(cudaStream_t stream,
// NOTE: we use alpha=1 for bias_grad to accumulate gradients
// use_bias == true: reduce output_grad over the batch with the all-ones
// vector m.one_ptr, i.e. bias_grad += ones * output_grad^T
if (bias_grad_ptr != NULL) {
checkCUDA(cublasGemmEx(m->handle.blas,
checkCUDA(cublasGemmEx(m.handle.blas,
CUBLAS_OP_N,
CUBLAS_OP_T,
1,
out_dim,
batch_size,
&alpha,
m->one_ptr,
m.one_ptr,
CUDA_R_32F,
1,
output_grad_ptr,
@@ -271,7 +273,7 @@ void backward_kernel(cudaStream_t stream,
// Compute data gradient: input_grad += W * output_grad
// NOTE: we use alpha=1 for input_grad to accumulate gradients
if (input_grad_ptr != NULL) {
checkCUDA(cublasGemmEx(m->handle.blas,
checkCUDA(cublasGemmEx(m.handle.blas,
CUBLAS_OP_N,
CUBLAS_OP_N,
in_dim,
2 changes: 1 addition & 1 deletion lib/runtime/src/ops/element_binary.cc
@@ -213,7 +213,7 @@ OpTaskSignature init_signature<ELEMENTBINARY_INIT_TASK_ID>() {

init.add_return_value<ElementBinaryPerDeviceState>();

return init;
return init; // todo: this may be wrong; the header file declares the return type as void
}

template <>