Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
411 changes: 252 additions & 159 deletions lib/kernels/src/hip/ops/attention_kernels.cpp

Large diffs are not rendered by default.

8 changes: 2 additions & 6 deletions lib/kernels/src/hip/ops/batch_matmul_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
*/

#include "kernels/batch_matmul_kernels.h"
#include "kernels/hip_helper.h"
#include "device.h"
#include <hip/hip_runtime.h>

namespace FlexFlow {
Expand All @@ -38,13 +38,9 @@ void forward_kernel(hipStream_t stream,
int batch,
int a_seq_length_dim,
int b_seq_length_dim,
int seq_length = -1) {
int seq_length) {
checkCUDA(hipblasSetStream(handle.blas, stream));
checkCUDNN(miopenSetStream(handle.dnn, stream));

// int a_stride = n * k;
// int b_stride = m * k;
// int o_stride = n * m;
int lda = k;
int ldb = m;
int ldo = m;
Expand Down
84 changes: 43 additions & 41 deletions lib/kernels/src/hip/ops/batch_norm_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,60 +14,48 @@
*/

#include "kernels/batch_norm_kernels.h"
#include "kernels/hip_helper.h"
#include "device.h"
#include "kernels/allocation.h"
#include "kernels/ff_handle.h"
#include <hip/hip_runtime.h>

namespace FlexFlow {

// declare Legion names
using Legion::Context;
using Legion::coord_t;
using Legion::Domain;
using Legion::Machine;
using Legion::Memory;
using Legion::PhysicalRegion;
using Legion::Rect;
using Legion::Runtime;
using Legion::Task;

#define MIOPEN_BN_MIN_EPSILON 0.001

namespace Kernels {
namespace BatchNorm {

void forward_kernel(hipStream_t stream,
BatchNormPerDeviceState const *m,
BatchNormPerDeviceState const &m,
float const *input_ptr,
float *output_ptr,
float const *scale_ptr,
float const *bias_ptr) {

checkCUDNN(miopenSetStream(m->handle.dnn, stream));
checkCUDNN(miopenSetStream(m.handle.dnn, stream));

float alpha = 1.0f, beta = 0.0f;
// coord_t numChannels = m->numChannels;
checkCUDNN(miopenBatchNormalizationForwardTraining(
m->handle.dnn,
m->mode,
m.handle.dnn,
m.mode,
&alpha,
&beta,
m->inputTensor,
m.inputTensor,
input_ptr,
m->outputTensor,
m.outputTensor,
output_ptr,
m->biasTensor,
m.biasTensor,
static_cast<void *>(const_cast<float *>(scale_ptr)),
static_cast<void *>(const_cast<float *>(bias_ptr)),
1.0,
m->runningMean,
m->runningVar,
m.runningMean,
m.runningVar,
MIOPEN_BN_MIN_EPSILON,
m->saveMean,
m->saveVar));
m.saveMean,
m.saveVar));
}

void backward_kernel(hipStream_t stream,
BatchNormPerDeviceState *m,
BatchNormPerDeviceState &m,
float const *input_ptr,
float *output_grad_ptr,
float const *output_ptr,
Expand All @@ -77,10 +65,10 @@ void backward_kernel(hipStream_t stream,
float *bias_grad_ptr,
size_t numElements) {

checkCUDNN(miopenSetStream(m->handle.dnn, stream));
checkCUDNN(miopenSetStream(m.handle.dnn, stream));

float alpha = 1.0f;
if (m->relu) {
if (m.relu) {
hipLaunchKernelGGL(reluBackward,
GET_BLOCKS(numElements),
CUDA_NUM_THREADS,
Expand All @@ -90,28 +78,28 @@ void backward_kernel(hipStream_t stream,
output_ptr,
numElements);
}
checkCUDNN(miopenBatchNormalizationBackward(m->handle.dnn,
m->mode,
checkCUDNN(miopenBatchNormalizationBackward(m.handle.dnn,
m.mode,
&alpha,
&alpha,
&alpha,
&alpha,
m->inputTensor,
m.inputTensor,
input_ptr,
m->outputTensor,
m.outputTensor,
output_grad_ptr,
m->inputTensor,
m.inputTensor,
input_grad_ptr,
m->biasTensor,
m.biasTensor,
scale_ptr,
scale_grad_ptr,
bias_grad_ptr,
MIOPEN_BN_MIN_EPSILON,
m->saveMean,
m->saveVar));
m.saveMean,
m.saveVar));
}

BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handler,
BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle,
Allocator allocator,
float *runningMean,
int output_n,
Expand All @@ -128,9 +116,6 @@ BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handler,
checkCUDNN(miopenCreateTensorDescriptor(&biasTensor));
checkCUDNN(miopenCreateTensorDescriptor(&outputTensor));
mode = miopenBNSpatial;
// #if HIPDNN_VERSION >= 7000
// mode = HIPDNN_BATCHNORM_SPATIAL_PERSISTENT;
// #endif
fprintf(
stderr, "output(%d,%d,%d,%d)\n", output_n, output_c, output_h, output_w);
checkCUDNN(miopenSet4dTensorDescriptor(
Expand Down Expand Up @@ -170,6 +155,23 @@ BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handler,
checkCUDNN(miopenSetActivationDescriptor(
actiDesc, miopenActivationRELU, 0.0, 0.0, 0.0));
}

BatchNormPerDeviceState per_device_state = {handle,
inputTensor,
outputTensor,
biasTensor,
actiDesc,
mode,
runningMean,
runningVar,
saveMean,
saveVar,
output_n,
output_c,
output_h,
output_w,
relu};
return per_device_state;
}

void cleanup_kernel(Allocator allocator,
Expand Down
4 changes: 3 additions & 1 deletion lib/kernels/src/hip/ops/cast_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
*/

#include "kernels/cast_kernels.h"
#include "device.h"
#include "kernels/datatype_dispatch.h"
#include "kernels/hip_helper.h"
#include <hip/hip_runtime.h>

namespace FlexFlow {
Expand Down Expand Up @@ -73,6 +73,7 @@ struct BackwardKernel {
};

void forward_kernel(ffStream_t stream,
PerDeviceFFHandle handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
DataType input_type,
Expand All @@ -82,6 +83,7 @@ void forward_kernel(ffStream_t stream,
}

void backward_kernel(ffStream_t stream,
PerDeviceFFHandle handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
DataType input_type,
Expand Down
3 changes: 2 additions & 1 deletion lib/kernels/src/hip/ops/combine_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
*/

#include "kernels/combine_kernels.h"
#include "device.h"
#include "kernels/accessor.h"
#include "kernels/datatype_dispatch.h"
#include "kernels/hip_helper.h"
#include <hip/hip_runtime.h>

namespace FlexFlow {
Expand Down
3 changes: 2 additions & 1 deletion lib/kernels/src/hip/ops/concat_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
*/

#include "kernels/concat_kernels.h"
#include "kernels/hip_helper.h"
#include "device.h"
#include <cassert>
#include <hip/hip_runtime.h>

namespace FlexFlow {
Expand Down
97 changes: 45 additions & 52 deletions lib/kernels/src/hip/ops/conv_2d_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,43 @@ namespace FlexFlow {
namespace Kernels {
namespace Conv2D {

miopenConvBwdDataAlgorithm_t selectConvolutionBackwardDataAlgorithm(
miopenHandle_t handle,
const miopenTensorDescriptor_t wDesc,
void const *w,
const miopenTensorDescriptor_t dyDesc,
void const *dy,
const miopenConvolutionDescriptor_t convDesc,
void *workSpace,
size_t workSpaceSize,
const miopenTensorDescriptor_t dxDesc,
void *dx,
float *time) {
int const reqAlgCnt = 8;
int cnt = 0;
miopenConvAlgoPerf_t perfResults[reqAlgCnt];
checkCUDNN(miopenFindConvolutionBackwardDataAlgorithm(handle,
dyDesc,
dy,
wDesc,
w,
convDesc,
dxDesc,
dx,
reqAlgCnt,
&cnt,
perfResults,
workSpace,
workSpaceSize,
false));
assert(cnt > 0);
// checkCUDNN(perfResults[0].status);
if (time != nullptr) {
*time = perfResults[0].time;
}
return perfResults[0].bwd_data_algo;
}

miopenConvFwdAlgorithm_t selectConvolutionForwardAlgorithm(
miopenHandle_t handle,
const miopenTensorDescriptor_t xDesc,
Expand Down Expand Up @@ -49,7 +86,7 @@ miopenConvFwdAlgorithm_t selectConvolutionForwardAlgorithm(
workSpaceSize,
false));
assert(cnt > 0);
// checkCUDNN(perfResults[0].status);
checkCUDNN(perfResults[0].status);
if (time != nullptr) {
*time = perfResults[0].time;
}
Expand Down Expand Up @@ -86,50 +123,13 @@ miopenConvBwdWeightsAlgorithm_t selectConvolutionBackwardFilterAlgorithm(
workSpaceSize,
false));
assert(cnt > 0);
// checkCUDNN(perfResults[0].status);
checkCUDNN(perfResults[0].status);
if (time != nullptr) {
*time = perfResults[0].time;
}
return perfResults[0].bwd_weights_algo;
}

miopenConvBwdDataAlgorithm_t selectConvolutionBackwardDataAlgorithm(
miopenHandle_t handle,
const miopenTensorDescriptor_t wDesc,
void const *w,
const miopenTensorDescriptor_t dyDesc,
void const *dy,
const miopenConvolutionDescriptor_t convDesc,
void *workSpace,
size_t workSpaceSize,
const miopenTensorDescriptor_t dxDesc,
void *dx,
float *time) {
int const reqAlgCnt = 8;
int cnt = 0;
miopenConvAlgoPerf_t perfResults[reqAlgCnt];
checkCUDNN(miopenFindConvolutionBackwardDataAlgorithm(handle,
dyDesc,
dy,
wDesc,
w,
convDesc,
dxDesc,
dx,
reqAlgCnt,
&cnt,
perfResults,
workSpace,
workSpaceSize,
false));
assert(cnt > 0);
// checkCUDNN(perfResults[0].status);
if (time != nullptr) {
*time = perfResults[0].time;
}
return perfResults[0].bwd_data_algo;
}

Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle,
std::optional<Activation> activation,
int kernel_h,
Expand Down Expand Up @@ -182,25 +182,19 @@ Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle,
checkCUDNN(miopenSet4dTensorDescriptor(
filterDesc, miopenFloat, output_c, input_c / groups, kernel_h, kernel_w));

checkCUDNN(miopenInitConvolutionDescriptor(convDesc,
miopenConvolution,
pad_h, // conv->padding_h,
pad_w, // conv->padding_w,
stride_h,
stride_w,
1 /*upscale_x*/,
1 /*upscale_y*/));
checkCUDNN(miopenInitConvolutionDescriptor(
convDesc, miopenConvolution, pad_h, pad_w, stride_h, stride_w, 1, 1));

if (groups != 1) {
checkCUDNN(miopenSetConvolutionGroupCount(convDesc, groups));
}

// TODO: enable tensor core when possible
if (handle.allowTensorOpMathConversion) {
// checkCUDNN(hipdnnSetConvolutionMathType(m->convDesc,
// CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION));
checkCUDNN(hipdnnSetConvolutionMathType(
m.convDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION));
} else {
// checkCUDNN(hipdnnSetConvolutionMathType(m->convDesc,
// HIPDNN_TENSOR_OP_MATH));
checkCUDNN(hipdnnSetConvolutionMathType(m.convDesc, HIPDNN_TENSOR_OP_MATH));
}

int n, c, h, w;
Expand Down Expand Up @@ -298,7 +292,6 @@ void forward_kernel(hipStream_t stream,
m.handle.workSpace,
m.handle.workSpaceSize));

// use_bias == True
if (bias_ptr != NULL) {
checkCUDNN(miopenConvolutionForwardBias(m.handle.dnn,
&alpha,
Expand Down