diff --git a/lib/kernels/src/hip/softmax_kernels.cpp b/lib/kernels/src/hip/softmax_kernels.cpp index e7bec53962..3a8f2813b7 100644 --- a/lib/kernels/src/hip/softmax_kernels.cpp +++ b/lib/kernels/src/hip/softmax_kernels.cpp @@ -14,40 +14,36 @@ */ #include "kernels/softmax_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" #include namespace FlexFlow { -// declare Legion names -using Legion::Domain; - -SoftmaxPerDeviceState::SoftmaxPerDeviceState(FFHandler handler, - Softmax const *softmax, - Domain const &input_domain) - : PerDeviceOpState(handler) { - checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); - checkCUDNN(cudnnSetTensorDescriptorFromDomain(inputTensor, input_domain)); - dim = softmax->dim; - profiling = softmax->profiling; - std::strcpy(op_name, softmax->name); -} namespace Kernels { namespace Softmax { +SoftmaxPerDeviceState init_kernel(PerDeviceFFHandle const &handle, int dim) { + ffTensorDescriptor_t inputTensor; + + checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); + + SoftmaxPerDeviceState per_device_state = {handle, inputTensor, dim}; + return per_device_state; +} + void forward_kernel(hipStream_t stream, - SoftmaxPerDeviceState const *m, + SoftmaxPerDeviceState const &m, float const *input_ptr, float *output_ptr) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); float alpha = 1.0f, beta = 0.0f; - checkCUDNN(miopenSoftmaxForward_V2(m->handle.dnn, + checkCUDNN(miopenSoftmaxForward_V2(m.handle.dnn, &alpha, - m->inputTensor, + m.inputTensor, input_ptr, &beta, - m->inputTensor, + m.inputTensor, output_ptr, MIOPEN_SOFTMAX_ACCURATE, MIOPEN_SOFTMAX_MODE_CHANNEL)); diff --git a/lib/kernels/src/hip/split_kernels.cpp b/lib/kernels/src/hip/split_kernels.cpp index 439e715c88..5599ae6d6f 100644 --- a/lib/kernels/src/hip/split_kernels.cpp +++ b/lib/kernels/src/hip/split_kernels.cpp @@ -14,12 +14,10 @@ */ #include "kernels/split_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" #include namespace FlexFlow { -// declare Legion names -using Legion::coord_t; namespace Kernels { namespace Split { diff --git a/lib/kernels/src/hip/topk_kernels.cpp b/lib/kernels/src/hip/topk_kernels.cpp index 4c9fa4f037..f085c5831f 100644 --- a/lib/kernels/src/hip/topk_kernels.cpp +++ b/lib/kernels/src/hip/topk_kernels.cpp @@ -14,15 +14,10 @@ */ #include "kernels/topk_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" #include namespace FlexFlow { -// declare Legion names -using Legion::coord_t; - -TopKPerDeviceState::TopKPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} namespace Kernels { namespace TopK { @@ -36,6 +31,11 @@ struct Entry { T value; }; +TopKPerDeviceState init_kernel(bool sorted) { + TopKPerDeviceState per_device_state = {sorted}; + return per_device_state; +} + template struct LinearData { typedef Entry Entry; @@ -371,7 +371,7 @@ __global__ void topk_forward_kernel(T const *__restrict__ input, } void forward_kernel(hipStream_t stream, - TopKPerDeviceState const *m, + TopKPerDeviceState const &m, float const *input_ptr, float *output_ptr, int *indices_ptr, @@ -428,7 +428,7 @@ __global__ void topk_backward_kernel(T const *__restrict__ value_grad_ptr, } void backward_kernel(hipStream_t stream, - TopKPerDeviceState const *m, + TopKPerDeviceState const &m, float const *value_grad_ptr, int const *indices_ptr, float *in_grad_ptr, diff --git a/lib/kernels/src/hip/transpose_kernels.cpp b/lib/kernels/src/hip/transpose_kernels.cpp index de64c74719..ef9dd58c63 100644 --- a/lib/kernels/src/hip/transpose_kernels.cpp +++ b/lib/kernels/src/hip/transpose_kernels.cpp @@ -14,13 +14,12 @@ */ #include "kernels/transpose_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" +#include "kernels/accessor.h" +#include "utils/exception.h" #include namespace FlexFlow { -// declare Legion names -using Legion::coord_t; -using Legion::Domain; struct TransposeStrides { int num_dim; @@ -31,81 +30,103 @@ struct TransposeStrides { namespace Kernels { namespace Transpose { +TransposePerDeviceState init_kernel(int num_dim, + std::vector const &perm) { + int const length = perm.size(); + + std::vector perm_vector; + assert(length <= MAX_TENSOR_DIM); + for (int i = 0; i < length; ++i) { + perm_vector.push_back(perm[i].value()); + } + + return {num_dim, perm_vector}; +} + +__global__ void transpose_simple_kernel(std::size_t volume, + float const *in_ptr, + float *out_ptr, + const TransposeStrides info, + float const beta) { + CUDA_KERNEL_LOOP(o_idx, volume) { + coord_t i_idx = 0; + coord_t t = o_idx; + for (int i = info.num_dim - 1; i >= 0; i--) { + coord_t ratio = t / info.out_strides[i]; + t -= ratio * info.out_strides[i]; + i_idx += ratio * info.in_strides[info.perm[i]]; + } + out_ptr[o_idx] += out_ptr[o_idx] * beta + in_ptr[i_idx]; + } +} + void forward_kernel(hipStream_t stream, - TransposePerDeviceState const *m, - float const *input_ptr, - float *output_ptr, - Domain in_domain, - Domain out_domain) { + TransposePerDeviceState const &m, + GenericTensorAccessorW const &in_grad, + GenericTensorAccessorR const &out_grad) { TransposeStrides info; - info.num_dim = out_domain.get_dim(); - assert(info.num_dim == m->num_dim); + info.num_dim = in_grad.shape.num_dims(); + assert(info.num_dim == m.num_dim); for (int i = 0; i < info.num_dim; i++) { - int in_dim_size = (in_domain.hi()[i] - in_domain.lo()[i] + 1); - int out_dim_size = (out_domain.hi()[i] - out_domain.lo()[i] + 1); - info.in_strides[i] = (i == 0) ? 1 : info.in_strides[i - 1] * in_dim_size; - info.out_strides[i] = (i == 0) ? 1 : info.out_strides[i - 1] * out_dim_size; - info.perm[i] = m->perm[i]; + if (i == 0) { + info.in_strides[i] = 1; + info.out_strides[i] = 1; + } else { + int in_dim_size = input.shape[legion_dim_t(i)] + 1; + int out_dim_size = output.shape[legion_dim_t(i)] + 1; + info.in_strides[i] = info.in_strides[i - 1] * in_dim_size; + info.out_strides[i] = info.out_strides[i - 1] * out_dim_size; + } + info.perm[i] = m.perm[i]; } + hipLaunchKernelGGL(transpose_simple_kernel, - GET_BLOCKS(out_domain.get_volume()), + GET_BLOCKS(output.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, - out_domain.get_volume(), - input_ptr, - output_ptr, + output.shape.get_volume(), + input.get_float_ptr(), + output.get_float_ptr(), info, 0.0f /*beta*/); } void backward_kernel(hipStream_t stream, - TransposePerDeviceState const *m, + TransposePerDeviceState const &m, float *input_grad_ptr, float const *output_grad_ptr, Domain in_grad_domain, Domain out_grad_domain) { TransposeStrides info; - info.num_dim = in_grad_domain.get_dim(); - assert(info.num_dim == m->num_dim); + info.num_dim = in_grad.shape.num_dims(); + assert(info.num_dim == m.num_dim); for (int i = 0; i < info.num_dim; i++) { - int in_dim_size = (out_grad_domain.hi()[i] - out_grad_domain.lo()[i] + 1); - int out_dim_size = (in_grad_domain.hi()[i] - in_grad_domain.lo()[i] + 1); - info.in_strides[i] = (i == 0) ? 1 : info.in_strides[i - 1] * in_dim_size; - info.out_strides[i] = (i == 0) ? 1 : info.out_strides[i - 1] * out_dim_size; - info.perm[m->perm[i]] = i; + if (i == 0) { + info.in_strides[i] = 1; + info.out_strides[i] = 1; + } else { + int in_dim_size = out_grad.shape[legion_dim_t(i)] + 1; + int out_dim_size = in_grad.shape[legion_dim_t(i)] + 1; + info.in_strides[i] = info.in_strides[i - 1] * in_dim_size; + info.out_strides[i] = info.out_strides[i - 1] * out_dim_size; + } + info.perm[m.perm[i]] = i; } hipLaunchKernelGGL(transpose_simple_kernel, - GET_BLOCKS(in_grad_domain.get_volume()), + GET_BLOCKS(in_grad.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, - in_grad_domain.get_volume(), - output_grad_ptr, - input_grad_ptr, + in_grad.shape.get_volume(), + out_grad.get_float_ptr(), + in_grad.get_float_ptr(), info, 1.0f /*beta*/); } -__global__ void transpose_simple_kernel(coord_t volume, - float const *in_ptr, - float *out_ptr, - const TransposeStrides info, - float const beta) { - CUDA_KERNEL_LOOP(o_idx, volume) { - coord_t i_idx = 0; - coord_t t = o_idx; - for (int i = info.num_dim - 1; i >= 0; i--) { - coord_t ratio = t / info.out_strides[i]; - t -= ratio * info.out_strides[i]; - i_idx += ratio * info.in_strides[info.perm[i]]; - } - out_ptr[o_idx] += out_ptr[o_idx] * beta + in_ptr[i_idx]; - } -} - } // namespace Transpose } // namespace Kernels } // namespace FlexFlow