Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions lib/kernels/include/kernels/pool_2d_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,50 @@
#define _FLEXFLOW_OPS_KERNELS_POOL_2D_KERNELS_H

#include "kernels/device.h"
#include "kernels/ff_handle.h"
#include "op-attrs/activation.h"
#include "op-attrs/ops/pool_2d.h"
#include "utils/visitable.h"

namespace FlexFlow {

class Pool2DPerDeviceState : public PerDeviceOpState {
public:
Pool2DPerDeviceState(FFHandler handle);
struct Pool2DPerDeviceState {
PerDeviceFFHandle handle;
ffTensorDescriptor_t inputTensor, outputTensor;
ffActivationDescriptor_t actiDesc;
ffPoolingDescriptor_t poolDesc;
bool relu;
char op_name[MAX_OPNAME];
};
}

FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(Pool2DPerDeviceState,
handle,
inputTensor,
outputTensor,
actiDesc,
poolDesc,
relu);

namespace Kernels {
namespace Pool2D {

Pool2DPerDeviceState init_kernel(PerDeviceFFHandle handle,
optional<Activation> activation,
int input_w,
int input_h,
int input_c,
int input_n,
int output_w,
int output_h,
int output_c,
int output_n,
int pad_h,
int pad_w,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
PoolOp pool_type);

void init_kernel(Pool2DPerDeviceState *m,
int input_w,
int input_h,
Expand All @@ -36,12 +64,12 @@ void init_kernel(Pool2DPerDeviceState *m,
PoolType pool_type);

void forward_kernel(ffStream_t stream,
Pool2DPerDeviceState const *m,
Pool2DPerDeviceState const &m,
void const *input_ptr,
void *output_ptr);

void backward_kernel(ffStream_t stream,
Pool2DPerDeviceState const *m,
Pool2DPerDeviceState const &m,
void const *input_ptr,
void *input_grad_ptr,
void const *output_ptr,
Expand Down
110 changes: 87 additions & 23 deletions lib/kernels/src/cuda/pool_2d_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,79 @@

namespace FlexFlow {

Pool2DPerDeviceState::Pool2DPerDeviceState(FFHandler handler)
: PerDeviceOpState(handler) {
namespace Kernels {
namespace Pool2D {

Pool2DPerDeviceState init_kernel(PerDeviceFFHandle handle,
optional<Activation> activation,
int input_w,
int input_h,
int input_c,
int input_n,
int output_w,
int output_h,
int output_c,
int output_n,
int pad_h,
int pad_w,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
PoolOp pool_type) {
ffTensorDescriptor_t inputTensor;
ffTensorDescriptor_t outputTensor;
ffActivationDescriptor_t actiDesc;
ffPoolingDescriptor_t poolDesc;

checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor));
checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor));
checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc));
checkCUDNN(cudnnCreatePoolingDescriptor(&poolDesc));
}

namespace Kernels {
namespace Pool2D {
checkCUDNN(cudnnSetTensor4dDescriptor(inputTensor,
CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT,
input_n,
input_c,
input_h,
input_w));
cudnnPoolingMode_t mode;
if (pool_type == PoolOp::MAX) {
mode = CUDNN_POOLING_MAX;
} else {
assert(pool_type == PoolOp::AVG);
mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
}

checkCUDNN(cudnnSetPooling2dDescriptor(poolDesc,
mode,
CUDNN_PROPAGATE_NAN,
kernel_h,
kernel_w,
pad_h,
pad_w,
stride_h,
stride_w));

int n, c, h, w;
checkCUDNN(
cudnnGetPooling2dForwardOutputDim(poolDesc, inputTensor, &n, &c, &h, &w));
assert(n == output_n);
assert(c == output_c);
assert(h == output_h);
assert(w == output_w);

checkCUDNN(cudnnSetTensor4dDescriptor(
outputTensor, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
bool relu = false;
if (activation == Activation::RELU) {
relu = true;
}
Pool2DPerDeviceState state = {
handle, inputTensor, outputTensor, actiDesc, poolDesc, relu};
return state;
}

void init_kernel(Pool2DPerDeviceState *m,
int input_w,
Expand All @@ -44,7 +108,7 @@ void init_kernel(Pool2DPerDeviceState *m,
int stride_h,
int stride_w,
PoolType pool_type) {
checkCUDNN(cudnnSetTensor4dDescriptor(m->inputTensor,
checkCUDNN(cudnnSetTensor4dDescriptor(m.inputTensor,
CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT,
input_n,
Expand All @@ -59,7 +123,7 @@ void init_kernel(Pool2DPerDeviceState *m,
assert(pool_type == POOL_AVG);
mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
}
checkCUDNN(cudnnSetPooling2dDescriptor(m->poolDesc,
checkCUDNN(cudnnSetPooling2dDescriptor(m.poolDesc,
mode,
CUDNN_PROPAGATE_NAN,
kernel_h,
Expand All @@ -70,55 +134,55 @@ void init_kernel(Pool2DPerDeviceState *m,
stride_w));
int n, c, h, w;
checkCUDNN(cudnnGetPooling2dForwardOutputDim(
m->poolDesc, m->inputTensor, &n, &c, &h, &w));
m.poolDesc, m.inputTensor, &n, &c, &h, &w));
assert(n == output_n);
assert(c == output_c);
assert(h == output_h);
assert(w == output_w);

checkCUDNN(cudnnSetTensor4dDescriptor(
m->outputTensor, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
m.outputTensor, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
}

void forward_kernel(cudaStream_t stream,
Pool2DPerDeviceState const *m,
Pool2DPerDeviceState const &m,
void const *input_ptr,
void *output_ptr) {

checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
checkCUDNN(cudnnSetStream(m.handle.dnn, stream));

float alpha = 1.0f, beta = 0.0f;
checkCUDNN(cudnnPoolingForward(m->handle.dnn,
m->poolDesc,
checkCUDNN(cudnnPoolingForward(m.handle.dnn,
m.poolDesc,
&alpha,
m->inputTensor,
m.inputTensor,
input_ptr,
&beta,
m->outputTensor,
m.outputTensor,
output_ptr));
}

void backward_kernel(cudaStream_t stream,
Pool2DPerDeviceState const *m,
Pool2DPerDeviceState const &m,
void const *input_ptr,
void *input_grad_ptr,
void const *output_ptr,
void const *output_grad_ptr) {

checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
checkCUDNN(cudnnSetStream(m.handle.dnn, stream));

float alpha = 1.0f;
checkCUDNN(cudnnPoolingBackward(m->handle.dnn,
m->poolDesc,
checkCUDNN(cudnnPoolingBackward(m.handle.dnn,
m.poolDesc,
&alpha,
m->outputTensor,
m.outputTensor,
output_ptr,
m->outputTensor,
m.outputTensor,
output_grad_ptr,
m->inputTensor,
m.inputTensor,
input_ptr,
&alpha,
m->inputTensor,
m.inputTensor,
input_grad_ptr));
}

Expand Down
Loading