diff --git a/lib/kernels/src/hip/ops/attention_kernels.cpp b/lib/kernels/src/hip/ops/attention_kernels.cpp index 455fdba1cd..005cef30d1 100644 --- a/lib/kernels/src/hip/ops/attention_kernels.cpp +++ b/lib/kernels/src/hip/ops/attention_kernels.cpp @@ -14,206 +14,299 @@ */ #include "kernels/attention_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" #include namespace FlexFlow { - -// declare Legion names -using Legion::coord_t; -using Legion::Memory; - namespace Kernels { namespace MultiHeadAttention { -void forward_kernel(hipStream_t stream, - MultiHeadAttentionPerDeviceState const *m, - float const *query_ptr, - float const *key_ptr, - float const *value_ptr, - float const *weight_ptr, - float *output_ptr) { -#if 0 - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); - - checkCUDNN(cudnnMultiHeadAttnForward(m->handle.dnn, - m->attnDesc, -1, m->loWinIdx, m->hiWinIdx, - m->devQoSeqArray, m->devKvSeqArray, m->qDesc, - query_ptr, NULL/*residual*/, m->kDesc, key_ptr, - m->vDesc, value_ptr, m->oDesc, output_ptr, m->weightSize, - weight_ptr, m->handle.workSpaceSize, m->handle.workSpace, - m->reserveSpaceSize, m->reserveSpace)); -#endif -} - -void backward_kernel(hipStream_t stream, - MultiHeadAttentionPerDeviceState const *m, - float const *query_ptr, - float *query_grad_ptr, - float const *key_ptr, - float *key_grad_ptr, - float const *value_ptr, - float *value_grad_ptr, - float const *weight_ptr, - float *weight_grad_ptr, - float const *output_grad_ptr) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); - -#if 0 - checkCUDNN(cudnnMultiHeadAttnBackwardData(m->handle.dnn, - m->attnDesc, m->loWinIdx, m->hiWinIdx, m->devQoSeqArray, - m->devKvSeqArray, m->oDesc, output_grad_ptr, m->qDesc, - query_grad_ptr, query_ptr, m->kDesc, key_grad_ptr, key_ptr, - m->vDesc, value_grad_ptr, value_ptr, m->weightSize, weight_ptr, - m->handle.workSpaceSize, m->handle.workSpace, m->reserveSpaceSize, - m->reserveSpace)); - 
checkCUDNN(cudnnMultiHeadAttnBackwardWeights(m->handle.dnn, - m->attnDesc, CUDNN_WGRAD_MODE_ADD, m->qDesc, - query_ptr, m->kDesc, key_ptr, m->vDesc, value_ptr, m->oDesc, - output_grad_ptr, m->weightSize, weight_ptr, weight_grad_ptr, - m->handle.workSpaceSize, m->handle.workSpace, - m->reserveSpaceSize, m->reserveSpace)); -#endif -} - -} // namespace MultiHeadAttention -} // namespace Kernels - -MultiHeadAttentionPerDeviceState::MultiHeadAttentionPerDeviceState( - FFHandler handler, - MultiHeadAttention const *attn, - Memory gpu_mem, - int num_samples, - int num_heads) - : PerDeviceOpState(handler) { +MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator &allocator, + int num_samples, + int num_heads, + int qSize, + int kSize, + int vSize, + int qProjSize, + int kProjSize, + int vProjSize, + int oProjSize, + int qoSeqLength, + int kvSeqLength, + bool add_bias_kv) { hipStream_t stream; + ffAttnDescriptor_t attnDesc; + ffSeqDataDescriptor_t qDesc; + ffSeqDataDescriptor_t kDesc; + ffSeqDataDescriptor_t vDesc; + ffSeqDataDescriptor_t oDesc; + void *reserveSpace; + void *dropoutStates; + int *devQoSeqArray; + int *devKvSeqArray; + size_t reserveSpaceSize; + size_t dropoutStateSize; + size_t weightSize; + checkCUDA(get_legion_stream(&stream)); checkCUDNN(miopenSetStream(handler.dnn, stream)); + checkCUDNN(miopenCreateAttnDescriptor(&attnDesc)); + checkCUDNN(miopenCreateSeqDataDescriptor(&qDesc)); + checkCUDNN(miopenCreateSeqDataDescriptor(&kDesc)); + checkCUDNN(miopenCreateSeqDataDescriptor(&vDesc)); + checkCUDNN(miopenCreateSeqDataDescriptor(&oDesc)); + + assert(!add_bias_kv); + miopenAttnQueryMap_t attnMode = MIOPEN_ATTN_QUERYMAP_ALL_TO_ONE; -#if 0 - checkCUDNN(cudnnCreateAttnDescriptor(&attnDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&qDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&kDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&vDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&oDesc)); - // Currently do not support adding bias to 
key/value projection - assert(!attn->add_bias_kv); - cudnnAttnQueryMap_t attnMode = CUDNN_ATTN_QUERYMAP_ALL_TO_ONE; - // Assume no beam search for now int maxBeamSize = 1; - //printf("batchSize(%d) qSize(%d) kSize(%d) vSize(%d) qProjSize(%d) kProjSize(%d)\n", - // num_samples, attn->qSize, attn->kSize, attn->vSize, attn->qProjSize, attn->kProjSize); - //printf("vProjSize(%d) oProjSize(%d) qoSeqLength(%d) kvSeqLength(%d)\n", - // attn->vProjSize, attn->oProjSize, attn->qoSeqLength, attn->kvSeqLength); + hipdnnMathType_t math_type; if (handle.allowTensorOpMathConversion) { - math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; + math_type = HIPDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; } else { math_type = HIPDNN_TENSOR_OP_MATH; } - checkCUDNN(cudnnSetAttnDescriptor(attnDesc, attnMode, num_heads, - 1.0f/*smScalar*/, HIPDNN_DATA_FLOAT, HIPDNN_DATA_FLOAT, math_type, - NULL/*attnDropoutDesc*/, NULL/*postDropoutDesc*/, - attn->qSize, attn->kSize, attn->vSize, attn->qProjSize, attn->kProjSize, - attn->vProjSize, attn->oProjSize, attn->qoSeqLength, attn->kvSeqLength, - num_samples, maxBeamSize)); + checkCUDNN(miopenSetAttnDescriptor(attnDesc, + attnMode, + num_heads, + 1.0f /*smScalar*/, + HIPDNN_DATA_FLOAT, + HIPDNN_DATA_FLOAT, + math_type, + NULL /*attnDropoutDesc*/, + NULL /*postDropoutDesc*/, + qSize, + kSize, + vSize, + qProjSize, + kProjSize, + vProjSize, + oProjSize, + qoSeqLength, + kvSeqLength, + num_samples, + maxBeamSize)); size_t workSpaceSize; - checkCUDNN(cudnnGetMultiHeadAttnBuffers(handler.dnn, attnDesc, &weightSize, - &workSpaceSize, &reserveSpaceSize)); + checkCUDNN(miopenGetMultiHeadAttnBuffers( + handle.dnn, attnDesc, &weightSize, &workSpaceSize, &reserveSpaceSize)); - assert(workSpaceSize <= handler.workSpaceSize); + assert(workSpaceSize <= handle.workSpaceSize); - //printf("weightSize(%zu) workSpaceSize(%zu) reserveSpaceSize(%zu)\n", weightSize, workSpaceSize, reserveSpaceSize); - int dimA[CUDNN_SEQDATA_DIM_COUNT]; - cudnnSeqDataAxis_t axes[CUDNN_SEQDATA_DIM_COUNT]; - assert(CUDNN_SEQDATA_DIM_COUNT == 4); - 
axes[3] = CUDNN_SEQDATA_VECT_DIM; // 3 = nbDims-1 - axes[2] = CUDNN_SEQDATA_BEAM_DIM; - axes[1] = CUDNN_SEQDATA_TIME_DIM; - axes[0] = CUDNN_SEQDATA_BATCH_DIM; - int *qoSeqArray = (int*) malloc(sizeof(int) * num_samples); - int *kvSeqArray = (int*) malloc(sizeof(int) * num_samples); + + int dimA[MIOPEN_SEQDATA_DIM_COUNT]; + miopenSeqDataAxis_t axes[MIOPEN_SEQDATA_DIM_COUNT]; + assert(MIOPEN_SEQDATA_DIM_COUNT == 4); + axes[3] = MIOPEN_SEQDATA_VECT_DIM; // 3 = nbDims-1 + axes[2] = MIOPEN_SEQDATA_BEAM_DIM; + axes[1] = MIOPEN_SEQDATA_TIME_DIM; + axes[0] = MIOPEN_SEQDATA_BATCH_DIM; + + std::unique_ptr qoSeqArray(new int[num_samples]); + std::unique_ptr kvSeqArray(new int[num_samples]); for (int i = 0; i < num_samples; i++) { - qoSeqArray[i] = attn->qoSeqLength; - kvSeqArray[i] = attn->kvSeqLength; + qoSeqArray[i] = qoSeqLength; + kvSeqArray[i] = kvSeqLength; } + // Set qDesc { - dimA[CUDNN_SEQDATA_BEAM_DIM] = 1; - dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; - dimA[CUDNN_SEQDATA_TIME_DIM] = attn->qoSeqLength; - dimA[CUDNN_SEQDATA_VECT_DIM] = attn->qSize; - checkCUDNN(cudnnSetSeqDataDescriptor(qDesc, - HIPDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, axes, - num_samples, qoSeqArray, NULL)); + dimA[MIOPEN_SEQDATA_BEAM_DIM] = 1; + dimA[MIOPEN_SEQDATA_BATCH_DIM] = num_samples; + dimA[MIOPEN_SEQDATA_TIME_DIM] = qoSeqLength; + dimA[MIOPEN_SEQDATA_VECT_DIM] = qSize; + checkCUDNN(miopenSetSeqDataDescriptor(qDesc, + MIOPEN_DATA_FLOAT, + MIOPEN_SEQDATA_DIM_COUNT, + dimA, + axes, + num_samples, + qoSeqArray.get(), + NULL)); } // Set kDesc { - dimA[CUDNN_SEQDATA_BEAM_DIM] = 1; - dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; - dimA[CUDNN_SEQDATA_TIME_DIM] = attn->kvSeqLength; - dimA[CUDNN_SEQDATA_VECT_DIM] = attn->kSize; - checkCUDNN(cudnnSetSeqDataDescriptor(kDesc, - HIPDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, axes, - num_samples, kvSeqArray, NULL)); + dimA[MIOPEN_SEQDATA_BEAM_DIM] = 1; + dimA[MIOPEN_SEQDATA_BATCH_DIM] = num_samples; + dimA[MIOPEN_SEQDATA_TIME_DIM] = 
kvSeqLength; + dimA[MIOPEN_SEQDATA_VECT_DIM] = kSize; + checkCUDNN(miopenSetSeqDataDescriptor(kDesc, + MIOPEN_DATA_FLOAT, + MIOPEN_SEQDATA_DIM_COUNT, + dimA, + axes, + num_samples, + kvSeqArray.get(), + NULL)); } // Set vDesc { - dimA[CUDNN_SEQDATA_BEAM_DIM] = 1; - dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; - dimA[CUDNN_SEQDATA_TIME_DIM] = attn->kvSeqLength; - dimA[CUDNN_SEQDATA_VECT_DIM] = attn->vSize; - checkCUDNN(cudnnSetSeqDataDescriptor(vDesc, - HIPDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, axes, - num_samples, kvSeqArray, NULL)); + dimA[MIOPEN_SEQDATA_BEAM_DIM] = 1; + dimA[MIOPEN_SEQDATA_BATCH_DIM] = num_samples; + dimA[MIOPEN_SEQDATA_TIME_DIM] = kvSeqLength; + dimA[MIOPEN_SEQDATA_VECT_DIM] = vSize; + checkCUDNN(miopenSetSeqDataDescriptor(vDesc, + MIOPEN_DATA_FLOAT, + MIOPEN_SEQDATA_DIM_COUNT, + dimA, + axes, + num_samples, + kvSeqArray.get(), + NULL)); } // Set oDesc { - dimA[CUDNN_SEQDATA_BEAM_DIM] = 1; - dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; - dimA[CUDNN_SEQDATA_TIME_DIM] = attn->qoSeqLength; - dimA[CUDNN_SEQDATA_VECT_DIM] = attn->oProjSize; - checkCUDNN(cudnnSetSeqDataDescriptor(oDesc, - HIPDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, axes, - num_samples, qoSeqArray, NULL)); + dimA[MIOPEN_SEQDATA_BEAM_DIM] = 1; + dimA[MIOPEN_SEQDATA_BATCH_DIM] = num_samples; + dimA[MIOPEN_SEQDATA_TIME_DIM] = qoSeqLength; + dimA[MIOPEN_SEQDATA_VECT_DIM] = oProjSize; + checkCUDNN(miopenSetSeqDataDescriptor(oDesc, + MIOPEN_DATA_FLOAT, + MIOPEN_SEQDATA_DIM_COUNT, + dimA, + axes, + num_samples, + qoSeqArray.get(), + NULL)); } + // allocate memory for the seqArray and reserve space { size_t totalSize = reserveSpaceSize + sizeof(int) * num_samples * 2; - Realm::Rect<1, coord_t> bounds(Realm::Point<1, coord_t>(0), Realm::Point<1, coord_t>(totalSize-1)); - std::vector field_sizes; - field_sizes.push_back(sizeof(char)); - Realm::RegionInstance::create_instance(reserveInst, gpu_mem, bounds, - field_sizes, 0, Realm::ProfilingRequestSet()).wait(); - devQoSeqArray = 
(int*) reserveInst.pointer_untyped(0, sizeof(char)); - checkCUDA(hipMemcpy(devQoSeqArray, qoSeqArray, sizeof(int) * num_samples, - hipMemcpyHostToDevice)); - devKvSeqArray = (int*)devQoSeqArray + num_samples; - checkCUDA(hipMemcpy(devKvSeqArray, kvSeqArray, sizeof(int) * num_samples, - hipMemcpyHostToDevice)); - reserveSpace = (int*)devKvSeqArray + num_samples; + + devQoSeqArray = (int *)allocator.allocate(totalSize); + checkCUDA(hipMemcpy(devQoSeqArray, + qoSeqArray.get(), + sizeof(int) * num_samples, + hipMemcpyHostToDevice)); + devKvSeqArray = devQoSeqArray + num_samples; + checkCUDA(hipMemcpy(devKvSeqArray, + kvSeqArray.get(), + sizeof(int) * num_samples, + hipMemcpyHostToDevice)); + reserveSpace = devKvSeqArray + num_samples; } // allocate memory for loWinIdx/hiWinIdx - loWinIdx = (int*) malloc(sizeof(int) * attn->qoSeqLength); - hiWinIdx = (int*) malloc(sizeof(int) * attn->qoSeqLength); - for (int i = 0; i < attn->qoSeqLength; i++) { + int *loWinIdx = (int *)malloc(sizeof(int) * qoSeqLength); + int *hiWinIdx = (int *)malloc(sizeof(int) * qoSeqLength); + for (int i = 0; i < qoSeqLength; i++) { loWinIdx[i] = 0; - hiWinIdx[i] = attn->kvSeqLength; + hiWinIdx[i] = kvSeqLength; } - free(qoSeqArray); - free(kvSeqArray); -#endif + + MHAPerDeviceState per_device_state = {handle, + weightSize, + reserveSpaceSize, + attnDesc, + qDesc, + kDesc, + vDesc, + oDesc, + devQoSeqArray, + devKvSeqArray, + loWinIdx, + hiWinIdx, + reserveSpace, + allocator}; + + return per_device_state; +} + +void forward_kernel(hipStream_t stream, + MHAPerDeviceState const &device_state, + float const *query_ptr, + float const *key_ptr, + float const *value_ptr, + float const *weight_ptr, + float *output_ptr) { + + checkCUDNN(miopenSetStream(device_state.handle.dnn, stream)); + + checkCUDNN(miopenMultiHeadAttnForward(device_state.handle.dnn, + device_state.attnDesc, + device_state.loWinIdx, + device_state.hiWinIdx, + device_state.devQoSeqArray, + device_state.devKvSeqArray, + 
device_state.oDesc, + output_ptr, + device_state.qDesc, + query_ptr, + device_state.kDesc, + key_ptr, + device_state.vDesc, + value_ptr, + weight_ptr, + device_state.weightSize, + device_state.reserveSpaceSize, + device_state.reserveSpace)); +} + +void backward_kernel(hipStream_t stream, + MHAPerDeviceState const &device_state, + float const *query_ptr, + float *query_grad_ptr, + float const *key_ptr, + float *key_grad_ptr, + float const *value_ptr, + float *value_grad_ptr, + float const *weight_ptr, + float *weight_grad_ptr, + float const *output_grad_ptr) { + checkCUDNN(miopenSetStream(device_state.handle.dnn, stream)); + + checkCUDNN(miopenMultiHeadAttnBackwardData(device_state.handle.dnn, + device_state.attnDesc, + device_state.loWinIdx, + device_state.hiWinIdx, + device_state.devQoSeqArray, + device_state.devKvSeqArray, + device_state.oDesc, + output_grad_ptr, + device_state.qDesc, + query_grad_ptr, + query_ptr, + device_state.kDesc, + key_grad_ptr, + key_ptr, + device_state.vDesc, + value_grad_ptr, + value_ptr, + weight_ptr, + device_state.weightSize, + device_state.reserveSpaceSize, + device_state.reserveSpace)); + + checkCUDNN(miopenMultiHeadAttnBackwardWeights(device_state.handle.dnn, + device_state.attnDesc, + device_state.loWinIdx, + device_state.hiWinIdx, + device_state.devQoSeqArray, + device_state.devKvSeqArray, + device_state.oDesc, + output_grad_ptr, + device_state.qDesc, + query_ptr, + device_state.kDesc, + key_ptr, + device_state.vDesc, + value_ptr, + weight_grad_ptr, + device_state.weightSize, + device_state.reserveSpaceSize, + device_state.reserveSpace)); } -MultiHeadAttentionPerDeviceState::~MultiHeadAttentionPerDeviceState(void) { -#if 0 - reserveInst.destroy(); - free(loWinIdx); - free(hiWinIdx); - checkCUDNN(cudnnDestroyAttnDescriptor(attnDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(qDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(kDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(vDesc)); - 
checkCUDNN(cudnnDestroySeqDataDescriptor(oDesc)); -#endif +void cleanup_kernel(Allocator &allocator, + MHAPerDeviceState const &device_state) { + free(device_state.loWinIdx); + free(device_state.hiWinIdx); + checkCUDNN(miopenDestroyAttnDescriptor(device_state.attnDesc)); + checkCUDNN(miopenDestroySeqDataDescriptor(device_state.qDesc)); + checkCUDNN(miopenDestroySeqDataDescriptor(device_state.kDesc)); + checkCUDNN(miopenDestroySeqDataDescriptor(device_state.vDesc)); + checkCUDNN(miopenDestroySeqDataDescriptor(device_state.oDesc)); } +} // namespace MultiHeadAttention +} // namespace Kernels } // namespace FlexFlow diff --git a/lib/kernels/src/hip/ops/batch_matmul_kernels.cpp b/lib/kernels/src/hip/ops/batch_matmul_kernels.cpp index cbfd669e0f..c4b3be823f 100644 --- a/lib/kernels/src/hip/ops/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/ops/batch_matmul_kernels.cpp @@ -14,7 +14,7 @@ */ #include "kernels/batch_matmul_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" #include namespace FlexFlow { @@ -38,13 +38,9 @@ void forward_kernel(hipStream_t stream, int batch, int a_seq_length_dim, int b_seq_length_dim, - int seq_length = -1) { + int seq_length) { checkCUDA(hipblasSetStream(handle.blas, stream)); checkCUDNN(miopenSetStream(handle.dnn, stream)); - - // int a_stride = n * k; - // int b_stride = m * k; - // int o_stride = n * m; int lda = k; int ldb = m; int ldo = m; diff --git a/lib/kernels/src/hip/ops/batch_norm_kernels.cpp b/lib/kernels/src/hip/ops/batch_norm_kernels.cpp index 768cc773c6..8e94b462cd 100644 --- a/lib/kernels/src/hip/ops/batch_norm_kernels.cpp +++ b/lib/kernels/src/hip/ops/batch_norm_kernels.cpp @@ -14,60 +14,48 @@ */ #include "kernels/batch_norm_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" +#include "kernels/allocation.h" +#include "kernels/ff_handle.h" #include namespace FlexFlow { - -// declare Legion names -using Legion::Context; -using Legion::coord_t; -using Legion::Domain;
-using Legion::Machine; -using Legion::Memory; -using Legion::PhysicalRegion; -using Legion::Rect; -using Legion::Runtime; -using Legion::Task; - -#define MIOPEN_BN_MIN_EPSILON 0.001 - namespace Kernels { namespace BatchNorm { void forward_kernel(hipStream_t stream, - BatchNormPerDeviceState const *m, + BatchNormPerDeviceState const &m, float const *input_ptr, float *output_ptr, float const *scale_ptr, float const *bias_ptr) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); float alpha = 1.0f, beta = 0.0f; // coord_t numChannels = m->numChannels; checkCUDNN(miopenBatchNormalizationForwardTraining( - m->handle.dnn, - m->mode, + m.handle.dnn, + m.mode, &alpha, &beta, - m->inputTensor, + m.inputTensor, input_ptr, - m->outputTensor, + m.outputTensor, output_ptr, - m->biasTensor, + m.biasTensor, static_cast(const_cast(scale_ptr)), static_cast(const_cast(bias_ptr)), 1.0, - m->runningMean, - m->runningVar, + m.runningMean, + m.runningVar, MIOPEN_BN_MIN_EPSILON, - m->saveMean, - m->saveVar)); + m.saveMean, + m.saveVar)); } void backward_kernel(hipStream_t stream, - BatchNormPerDeviceState *m, + BatchNormPerDeviceState &m, float const *input_ptr, float *output_grad_ptr, float const *output_ptr, @@ -77,10 +65,10 @@ void backward_kernel(hipStream_t stream, float *bias_grad_ptr, size_t numElements) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); float alpha = 1.0f; - if (m->relu) { + if (m.relu) { hipLaunchKernelGGL(reluBackward, GET_BLOCKS(numElements), CUDA_NUM_THREADS, @@ -90,28 +78,28 @@ void backward_kernel(hipStream_t stream, output_ptr, numElements); } - checkCUDNN(miopenBatchNormalizationBackward(m->handle.dnn, - m->mode, + checkCUDNN(miopenBatchNormalizationBackward(m.handle.dnn, + m.mode, &alpha, &alpha, &alpha, &alpha, - m->inputTensor, + m.inputTensor, input_ptr, - m->outputTensor, + m.outputTensor, output_grad_ptr, - m->inputTensor, + 
m.inputTensor, input_grad_ptr, - m->biasTensor, + m.biasTensor, scale_ptr, scale_grad_ptr, bias_grad_ptr, MIOPEN_BN_MIN_EPSILON, - m->saveMean, - m->saveVar)); + m.saveMean, + m.saveVar)); } -BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handler, +BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle, Allocator allocator, float *runningMean, int output_n, @@ -128,9 +116,6 @@ BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handler, checkCUDNN(miopenCreateTensorDescriptor(&biasTensor)); checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); mode = miopenBNSpatial; - // #if HIPDNN_VERSION >= 7000 - // mode = HIPDNN_BATCHNORM_SPATIAL_PERSISTENT; - // #endif fprintf( stderr, "output(%d,%d,%d,%d)\n", output_n, output_c, output_h, output_w); checkCUDNN(miopenSet4dTensorDescriptor( @@ -170,6 +155,23 @@ BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handler, checkCUDNN(miopenSetActivationDescriptor( actiDesc, miopenActivationRELU, 0.0, 0.0, 0.0)); } + + BatchNormPerDeviceState per_device_state = {handle, + inputTensor, + outputTensor, + biasTensor, + actiDesc, + mode, + runningMean, + runningVar, + saveMean, + saveVar, + output_n, + output_c, + output_h, + output_w, + relu}; + return per_device_state; } void cleanup_kernel(Allocator allocator, diff --git a/lib/kernels/src/hip/ops/cast_kernels.cpp b/lib/kernels/src/hip/ops/cast_kernels.cpp index cf0ea83275..fa0c37ffa1 100644 --- a/lib/kernels/src/hip/ops/cast_kernels.cpp +++ b/lib/kernels/src/hip/ops/cast_kernels.cpp @@ -14,8 +14,8 @@ */ #include "kernels/cast_kernels.h" +#include "device.h" #include "kernels/datatype_dispatch.h" -#include "kernels/hip_helper.h" #include namespace FlexFlow { @@ -73,6 +73,7 @@ struct BackwardKernel { }; void forward_kernel(ffStream_t stream, + PerDeviceFFHandle handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, @@ -82,6 +83,7 @@ void forward_kernel(ffStream_t stream, } void backward_kernel(ffStream_t stream, + 
PerDeviceFFHandle handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, diff --git a/lib/kernels/src/hip/ops/combine_kernels.cpp b/lib/kernels/src/hip/ops/combine_kernels.cpp index e3871e587b..aa01f02276 100644 --- a/lib/kernels/src/hip/ops/combine_kernels.cpp +++ b/lib/kernels/src/hip/ops/combine_kernels.cpp @@ -14,8 +14,9 @@ */ #include "kernels/combine_kernels.h" +#include "device.h" +#include "kernels/accessor.h" #include "kernels/datatype_dispatch.h" -#include "kernels/hip_helper.h" #include namespace FlexFlow { diff --git a/lib/kernels/src/hip/ops/concat_kernels.cpp b/lib/kernels/src/hip/ops/concat_kernels.cpp index 6eac034a4b..aa38be739b 100644 --- a/lib/kernels/src/hip/ops/concat_kernels.cpp +++ b/lib/kernels/src/hip/ops/concat_kernels.cpp @@ -14,7 +14,8 @@ */ #include "kernels/concat_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" +#include #include namespace FlexFlow { diff --git a/lib/kernels/src/hip/ops/conv_2d_kernels.cpp b/lib/kernels/src/hip/ops/conv_2d_kernels.cpp index 4d26c20c46..d36da8140b 100644 --- a/lib/kernels/src/hip/ops/conv_2d_kernels.cpp +++ b/lib/kernels/src/hip/ops/conv_2d_kernels.cpp @@ -19,6 +19,43 @@ namespace FlexFlow { namespace Kernels { namespace Conv2D { +miopenConvBwdDataAlgorithm_t selectConvolutionBackwardDataAlgorithm( + miopenHandle_t handle, + const miopenTensorDescriptor_t wDesc, + void const *w, + const miopenTensorDescriptor_t dyDesc, + void const *dy, + const miopenConvolutionDescriptor_t convDesc, + void *workSpace, + size_t workSpaceSize, + const miopenTensorDescriptor_t dxDesc, + void *dx, + float *time) { + int const reqAlgCnt = 8; + int cnt = 0; + miopenConvAlgoPerf_t perfResults[reqAlgCnt]; + checkCUDNN(miopenFindConvolutionBackwardDataAlgorithm(handle, + dyDesc, + dy, + wDesc, + w, + convDesc, + dxDesc, + dx, + reqAlgCnt, + &cnt, + perfResults, + workSpace, + workSpaceSize, + false)); + assert(cnt > 0); + // checkCUDNN(perfResults[0].status); + 
if (time != nullptr) { + *time = perfResults[0].time; + } + return perfResults[0].bwd_data_algo; +} + miopenConvFwdAlgorithm_t selectConvolutionForwardAlgorithm( miopenHandle_t handle, const miopenTensorDescriptor_t xDesc, @@ -49,7 +86,7 @@ miopenConvFwdAlgorithm_t selectConvolutionForwardAlgorithm( workSpaceSize, false)); assert(cnt > 0); - // checkCUDNN(perfResults[0].status); + // checkCUDNN(perfResults[0].status); if (time != nullptr) { *time = perfResults[0].time; } @@ -86,50 +123,13 @@ miopenConvBwdWeightsAlgorithm_t selectConvolutionBackwardFilterAlgorithm( workSpaceSize, false)); assert(cnt > 0); - // checkCUDNN(perfResults[0].status); + // checkCUDNN(perfResults[0].status); if (time != nullptr) { *time = perfResults[0].time; } return perfResults[0].bwd_weights_algo; } -miopenConvBwdDataAlgorithm_t selectConvolutionBackwardDataAlgorithm( - miopenHandle_t handle, - const miopenTensorDescriptor_t wDesc, - void const *w, - const miopenTensorDescriptor_t dyDesc, - void const *dy, - const miopenConvolutionDescriptor_t convDesc, - void *workSpace, - size_t workSpaceSize, - const miopenTensorDescriptor_t dxDesc, - void *dx, - float *time) { - int const reqAlgCnt = 8; - int cnt = 0; - miopenConvAlgoPerf_t perfResults[reqAlgCnt]; - checkCUDNN(miopenFindConvolutionBackwardDataAlgorithm(handle, - dyDesc, - dy, - wDesc, - w, - convDesc, - dxDesc, - dx, - reqAlgCnt, - &cnt, - perfResults, - workSpace, - workSpaceSize, - false)); - assert(cnt > 0); - // checkCUDNN(perfResults[0].status); - if (time != nullptr) { - *time = perfResults[0].time; - } - return perfResults[0].bwd_data_algo; -} - Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle, std::optional activation, int kernel_h, @@ -182,25 +182,19 @@ Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle, checkCUDNN(miopenSet4dTensorDescriptor( filterDesc, miopenFloat, output_c, input_c / groups, kernel_h, kernel_w)); - checkCUDNN(miopenInitConvolutionDescriptor(convDesc, - miopenConvolution, - pad_h, //
conv->padding_h, - pad_w, // conv->padding_w, - stride_h, - stride_w, - 1 /*upscale_x*/, - 1 /*upscale_y*/)); + checkCUDNN(miopenInitConvolutionDescriptor( + convDesc, miopenConvolution, pad_h, pad_w, stride_h, stride_w, 1, 1)); + if (groups != 1) { checkCUDNN(miopenSetConvolutionGroupCount(convDesc, groups)); } // TODO: enable tensor core when possible if (handle.allowTensorOpMathConversion) { - // checkCUDNN(hipdnnSetConvolutionMathType(m->convDesc, - // CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); + // checkCUDNN(hipdnnSetConvolutionMathType(convDesc, + // HIPDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); } else { - // checkCUDNN(hipdnnSetConvolutionMathType(m->convDesc, - // HIPDNN_TENSOR_OP_MATH)); + // checkCUDNN(hipdnnSetConvolutionMathType(convDesc, HIPDNN_TENSOR_OP_MATH)); } int n, c, h, w; @@ -298,7 +292,6 @@ void forward_kernel(hipStream_t stream, m.handle.workSpace, m.handle.workSpaceSize)); - // use_bias == True if (bias_ptr != NULL) { checkCUDNN(miopenConvolutionForwardBias(m.handle.dnn, &alpha,