diff --git a/lib/kernels/src/hip/partition_kernels.cpp b/lib/kernels/src/hip/partition_kernels.cpp index 3761da5c84..4591247faa 100644 --- a/lib/kernels/src/hip/partition_kernels.cpp +++ b/lib/kernels/src/hip/partition_kernels.cpp @@ -14,21 +14,17 @@ */ #include "kernels/partition_kernels.h" +#include "device.h" #include "kernels/datatype_dispatch.h" -#include "kernels/hip_helper.h" #include namespace FlexFlow { - -RepartitionPerDeviceState::RepartitionPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace Repartition { tempate struct ForwardKernel { void operator()(hipStream_t stream, - RepartitionPerDeviceState const *m, + RepartitionPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { checkCUDA(hipMemcpyAsync(output.get(), @@ -41,7 +37,7 @@ tempate struct ForwardKernel { tempate struct BackwardKernel { void operator()(hipStream_t stream, - RepartitionPerDeviceState const *m, + RepartitionPerDeviceState const &m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorW const &input_grad) { hipLaunchKernelGGL(HIP_KERNEL_NAME(add_kernel), @@ -55,19 +51,25 @@ tempate struct BackwardKernel { } } +RepartitionPerDeviceState + init_kernel(PerDeviceFFHandle const &handle, DataType data_type) { + RepartitionPerDeviceState per_device_state = {handle, data_type}; + return per_device_state; +} + void forward_kernel(hipStream_t stream, - RepartitionPerDeviceState const *m, + RepartitionPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - DataTypeDispatch1{}(m->data_type, stream, m, input, output) + DataTypeDispatch1{}(m.data_type, stream, m, input, output) } void backward_kernel(hipStream_t stream, - RepartitionPerDeviceState const *m, + RepartitionPerDeviceState const &m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorW const &input_grad) { DataTypeDispatch1{}( - m->data_type, stream, m, input_grad, output_grad) + m.data_type, stream, m, input_grad, output_grad) } } // namespace Repartition diff --git a/lib/kernels/src/hip/pool_2d_kernels.cpp b/lib/kernels/src/hip/pool_2d_kernels.cpp index 0bb44c3e1a..ed942c105c 100644 --- a/lib/kernels/src/hip/pool_2d_kernels.cpp +++ b/lib/kernels/src/hip/pool_2d_kernels.cpp @@ -14,116 +14,122 @@ */ #include "kernels/pool_2d_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" +#include namespace FlexFlow { -Pool2DPerDeviceState::Pool2DPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) { +namespace Kernels { +namespace Pool2D { + +Pool2DPerDeviceState init_kernel(PerDeviceFFHandle handle, + optional activation, + int input_w, + int input_h, + int input_c, + int input_n, + int output_w, + int output_h, + int output_c, + int output_n, + int pad_h, + int pad_w, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + PoolOp pool_type) { + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffPoolingDescriptor_t poolDesc; + ffActivationDescriptor_t actiDesc; + checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); checkCUDNN(miopenCreatePoolingDescriptor(&poolDesc)); -} + checkCUDNN(miopenCreateActivationDescriptor(&actiDesc)); -namespace Kernels { -namespace Pool2D { - -void init_kernel(Pool2DPerDeviceState *m, - int input_w, - int input_h, - int input_c, - int input_n, - int output_w, - int output_h, - int output_c, - int output_n, - int pad_h, - int pad_w, - int kernel_h, - int kernel_w, - int stride_h, - int stride_w, - PoolType pool_type) { checkCUDNN(miopenSet4dTensorDescriptor( - m->inputTensor, miopenFloat, input_n, input_c, input_h, input_w)); - + inputTensor, miopenFloat, input_n, input_c, input_h, input_w)); miopenPoolingMode_t mode; - if (pool_type == POOL_MAX) { + if (pool_type == PoolOp::MAX) { mode = miopenPoolingMax; } else { - assert(pool_type == POOL_AVG); + assert(pool_type == PoolOp::AVG); mode = miopenPoolingAverage; } - checkCUDNN(miopenSet2dPoolingDescriptor( - m->poolDesc, mode, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w)); + + checkCUDNN(miopenSetPooling2dDescriptor( + poolDesc, mode, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w)); + int n, c, h, w; - checkCUDNN(miopenGetPoolingForwardOutputDim( - m->poolDesc, m->inputTensor, &n, &c, &h, &w)); + checkCUDNN(miopenGetPooling2dForwardOutputDim( + poolDesc, inputTensor, &n, &c, &h, &w)); assert(n == output_n); assert(c == output_c); assert(h == output_h); assert(w == output_w); checkCUDNN( - miopenSet4dTensorDescriptor(m->outputTensor, miopenFloat, n, c, h, w)); + miopenSet4dTensorDescriptor(outputTensor, miopenFloat, n, c, h, w)); + bool relu = false; + if (activation == Activation::RELU) { + relu = true; + } + Pool2DPerDeviceState state = { + handle, + inputTensor, + outputTensor, + actiDesc, + poolDesc, + relu, + }; + return state; } void forward_kernel(hipStream_t stream, - Pool2DPerDeviceState const *m, + Pool2DPerDeviceState const &m, void const *input_ptr, void *output_ptr) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); float alpha = 1.0f, beta = 0.0f; - checkCUDNN(miopenPoolingForward(m->handle.dnn, - m->poolDesc, + checkCUDNN(miopenPoolingForward(m.handle.dnn, + m.poolDesc, &alpha, - m->inputTensor, + m.inputTensor, input_ptr, &beta, - m->outputTensor, + m.outputTensor, output_ptr, true, - m->handle.workSpace, - m->handle.workSpaceSize)); - if (m->profiling) { - hipEventRecord(t_end, stream); - checkCUDA(hipEventSynchronize(t_end)); - // print_tensor<4, float>(acc_input.ptr, acc_input.rect, - // "[Pool2D:forward:input]"); print_tensor<4, float>(acc_output.ptr, - // acc_output.rect, "[Pool2D:forward:output]"); - float elapsed = 0; - checkCUDA(hipEventElapsedTime(&elapsed, t_start, t_end)); - hipEventDestroy(t_start); - hipEventDestroy(t_end); - printf("%s [Pool2D] forward time = %.2fms\n", m->op_name, elapsed); - } + m.handle.workSpace, + m.handle.workSpaceSize)); } void backward_kernel(hipStream_t stream, - Pool2DPerDeviceState const *m, + Pool2DPerDeviceState const &m, void const *input_ptr, void *input_grad_ptr, void const *output_ptr, void const *output_grad_ptr) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); float alpha = 1.0f; - float beta = 0.0f; - checkCUDNN(miopenPoolingBackward(m->handle.dnn, - m->poolDesc, + checkCUDNN(miopenPoolingBackward(m.handle.dnn, + m.poolDesc, &alpha, - m->outputTensor, + m.outputTensor, output_ptr, - m->outputTensor, + m.outputTensor, output_grad_ptr, - m->inputTensor, + m.inputTensor, input_ptr, &beta, - m->inputTensor, + m.inputTensor, input_grad_ptr, - m->handle.workSpace)); + m.handle.workSpace)); } } // namespace Pool2D