Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 42 additions & 61 deletions lib/kernels/src/hip/ops/reduce_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,108 +14,89 @@
*/

#include "kernels/reduce_kernels.h"
#include "kernels/hip_helper.h"
#include "device.h"
#include <hip/hip_runtime.h>

namespace FlexFlow {
// declare Legion names
using Legion::coord_t;
using Legion::Domain;
namespace Kernels {
namespace Reduce {

ReducePerDeviceState init_kernel(PerDeviceFFHandle const &handle,
OperatorType const &op_type,
size_t const &reduction_size,
ArrayShape const &input_shape,
ArrayShape const &output_shape) {
ffTensorDescriptor_t inputTensor ffTensorDescriptor_t outputTensor;
ffReduceTensorDescriptor_t reduceDesc;

ReducePerDeviceState::ReducePerDeviceState(FFHandler handler,
Reduce const *rd,
Domain const &input_domain)
: op_type(rd->op_type), PerDeviceOpState(handler) {
checkCUDNN(miopenCreateReduceTensorDescriptor(&reduceDesc));
checkCUDNN(miopenCreateTensorDescriptor(&inputTensor));
checkCUDNN(miopenCreateTensorDescriptor(&outputTensor));
cudnnReduceTensorOp_t reduce_op;
switch (rd->op_type) {
case OP_REDUCE_SUM:
reduce_op = CUDNN_REDUCE_TENSOR_ADD;
break;
case OP_REDUCE_MEAN:
reduce_op = CUDNN_REDUCE_TENSOR_AVG;
break;
default:
assert(false);
}
checkCUDNN(miopenSetReduceTensorDescriptor(reduceDesc,
MIOPEN_REDUCE_TENSOR_ADD,
miopenFloat,
MIOPEN_PROPAGATE_NAN,
MIOPEN_REDUCE_TENSOR_NO_INDICES,
MIOPEN_32BIT_INDICES));
checkCUDNN(cudnnSetTensorDescriptorFromDomain(inputTensor, input_domain));
Domain output_domain = input_domain;
for (size_t i = 0; i < rd->num_axes; i++) {
assert(input_domain.dim > rd->axes[i]);
output_domain.rect_data[rd->axes[i] + output_domain.dim] =
output_domain.rect_data[rd->axes[i]];
}
assert(output_domain.get_volume() % input_domain.get_volume() == 0);
reduction_size = input_domain.get_volume() / output_domain.get_volume();
assert(reduction_size > 0);
checkCUDNN(cudnnSetTensorDescriptorFromDomain(outputTensor, output_domain));
}
checkCUDNN(miopenCreateReduceTensorDescriptor(&reduceDesc));

ReducePerDeviceState::~ReducePerDeviceState(void) {
checkCUDNN(miopenDestroyReduceTensorDescriptor(reduceDesc));
checkCUDNN(miopenDestroyTensorDescriptor(inputTensor));
checkCUDNN(miopenDestroyTensorDescriptor(outputTensor));
}
checkCUDNN(miopenSetTensorDescriptor(inputTensor,
miopenFloat,
input_shape.dims.size(),
input_shape.dims.data(),
input_shape.strides.data()));
checkCUDNN(miopenSetTensorDescriptor(outputTensor,
miopenFloat,
output_shape.dims.size(),
output_shape.dims.data(),
output_shape.strides.data()));

namespace Kernels {
namespace Reduce {
ReducePerDeviceState per_device = {
handle, inputTensor, outputTensor, reduceDesc, op_type, reduction_size};
return per_device;
}

void forward_kernel(hipStream_t stream,
ReducePerDeviceState const *m,
ReducePerDeviceState const &m,
float const *input_ptr,
float *output_ptr) {
checkCUDNN(miopenSetStream(m->handle.dnn, stream));
checkCUDNN(miopenSetStream(m.handle.dnn, stream));
float alpha = 1.0f, beta = 0.0f;
checkCUDNN(miopenReduceTensor(m->handle.dnn,
m->reduceDesc,
checkCUDNN(miopenReduceTensor(m.handle.dnn,
m.reduceDesc,
nullptr /*indices*/,
0 /*indicesSizeInBytes*/,
m->handle.workSpace,
m->handle.workSpaceSize,
m.handle.workSpace,
m.handle.workSpaceSize,
&alpha,
m->inputTensor,
m.inputTensor,
input_ptr,
&beta,
m->outputTensor,
m.outputTensor,
output_ptr));
};

void backward_kernel(hipStream_t stream,
ReducePerDeviceState const *m,
ReducePerDeviceState const &m,
float const *output_grad_ptr,
float *input_grad_ptr) {
checkCUDNN(miopenSetStream(m->handle.dnn, stream));
checkCUDNN(miopenSetStream(m.handle.dnn, stream));
float alpha = 1.0f, beta = 0.0f;
switch (m->op_type) {
switch (m.op_type) {
case OP_REDUCE_SUM:
alpha = 1.0f;
break;
case OP_REDUCE_MEAN:
// When the output is the average of multiple input elements
// we need to scale the gradients by 1.0 / reduction_size
alpha = 1.0f / m->reduction_size;
alpha = 1.0f / m.reduction_size;
break;
default:
assert(false);
}
checkCUDNN(miopenOpTensor(m->handle.dnn,
checkCUDNN(miopenOpTensor(m.handle.dnn,
miopenTensorOpAdd,
&alpha,
m->inputTensor,
m.inputTensor,
input_grad_ptr,
&alpha,
m->outputTensor,
m.outputTensor,
output_grad_ptr,
&beta,
m->inputTensor,
m.inputTensor,
input_grad_ptr));
}

Expand Down
8 changes: 4 additions & 4 deletions lib/kernels/src/hip/ops/reduction_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ __global__ void reduction_forward_kernel(T const *input_ptr,

template <DataType T>
struct ForwardKernel {
void operator()(cudaStream_t stream,
void operator()(hipStream_t stream,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
size_t num_replicas) {
Expand All @@ -57,7 +57,7 @@ struct ForwardKernel {

template <DataType T>
struct BackwardKernel {
void operator()(cudaStream_t stream,
void operator()(hipStream_t stream,
GenericTensorAccessorW const &input,
GenericTensorAccessorR const &output) {
checkCUDA(hipMemcpyAsync(input.get<T>(),
Expand All @@ -73,13 +73,13 @@ void forward_kernel(hipStream_t stream,
GenericTensorAccessorW const &output,
size_t num_replicas) {
DataTypeDispatch1<ForwardKernel>{}(
input->data_type, stream, input, output, num_replicas);
input.data_type, stream, input, output, num_replicas);
}

void backward_kernel(hipStream_t stream,
GenericTensorAccessorW const &input,
GenericTensorAccessorR const &output) {
DataTypeDispatch1<BackwardKernel>{}(input->data_type, stream, input, output);
DataTypeDispatch1<BackwardKernel>{}(input.data_type, stream, input, output);
}

} // namespace Reduction
Expand Down
35 changes: 18 additions & 17 deletions lib/kernels/src/hip/ops/replicate_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,26 @@
*/

#include "kernels/replicate_kernels.h"
#include "kernels/hip_helper.h"
#include "device.h"
#include "kernels/datatype_dispatch.h"
#include <hip/hip_runtime.h>

namespace FlexFlow {
namespace Kernels {
namespace Replicate {

// Accumulates the gradients of all replicas back into the single source
// gradient buffer: output_ptr[idx] receives the sum of input_ptr over the
// num_replicas copies laid out contiguously, each num_elements long.
// Grid-stride loop via CUDA_KERNEL_LOOP; launch must cover num_elements.
template <typename T>
__global__ void replicate_backward_kernel(T const *input_ptr,
                                          T *output_ptr,
                                          size_t num_elements,
                                          size_t num_replicas) {
  CUDA_KERNEL_LOOP(idx, num_elements) {
    // Accumulate in a register, then write back once (same final memory
    // state as repeated in-place +=).
    T acc = output_ptr[idx];
    for (size_t replica = 0; replica < num_replicas; ++replica) {
      acc += input_ptr[idx + replica * num_elements];
    }
    output_ptr[idx] = acc;
  }
}

template <DataType T>
struct ForwardKernel {
void operator()(hipStream_t stream,
Expand All @@ -29,7 +42,7 @@ struct ForwardKernel {

checkCUDA(hipMemcpyAsync(input.get<T>(),
output.get<T>(),
input.shape.num_elements() * sizeof(T),
input.shape.num_elements() * size_of_datatype(T),
hipMemcpyDeviceToDevice,
stream));
}
Expand All @@ -42,7 +55,7 @@ struct BackwardKernel {
GenericTensorAccessorR const &output,
size_t num_replicas) {
size_t total_elements = input.shape.num_elements() * num_replicas;
hipLaunchKernelGGL(HIP_KERNEL_NAME(replicate_backward_kernel<T>),
hipLaunchKernelGGL(HIP_KERNEL_NAME(replicate_backward_kernel<real_type<T>>),
GET_BLOCKS(total_elements),
CUDA_NUM_THREADS,
0,
Expand All @@ -54,30 +67,18 @@ struct BackwardKernel {
}
}

template <typename T>
__global__ void replicate_backward_kernel(T const *input_ptr,
T *output_ptr,
size_t num_elements,
size_t num_replicas) {
CUDA_KERNEL_LOOP(i, num_elements) {
for (size_t j = 0; j < num_replicas; j++) {
output_ptr[i] += input_ptr[i + j * num_elements];
}
}
}

void forward_kernel(hipStream_t stream,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
DataTypeDispatch1<ForwardKernel>{}(input->data_type, stream, input, output);
DataTypeDispatch1<ForwardKernel>{}(input.data_type, stream, input, output);
}

void backward_kernel(hipStream_t stream,
GenericTensorAccessorW const &input,
GenericTensorAccessorR const &output,
size_t num_replicas) {
DataTypeDispatch1<BackwardKernel>{}(
input->data_type, stream, input, output, num_replicas);
input.data_type, stream, input, output, num_replicas);
}

} // namespace Replicate
Expand Down
25 changes: 13 additions & 12 deletions lib/kernels/src/hip/ops/reshape_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,27 @@
*/

#include "kernels/reshape_kernels.h"
#include "device.h"
#include "kernels/datatype_dispatch.h"
#include "kernels/hip_helper.h"
#include <hip/hip_runtime.h>

namespace FlexFlow {

ReshapePerDeviceState::ReshapePerDeviceState(FFHandler handler)
: PerDeviceOpState(handler) {}

namespace Kernels {
namespace Reshape {

// Builds the per-device state for the Reshape op. Reshape is pure data
// movement, so the only state needed is the element type used later to
// dispatch the typed copy (forward) and accumulate (backward) kernels.
ReshapePerDeviceState init_kernel(DataType data_type) {
  return ReshapePerDeviceState{data_type};
}

template <DataType T>
struct ForwardKernel {
void operator()(hipStream_t stream,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
checkCUDA(hipMemcpyAsync(output.get<T>(),
input.get<T>(),
input.shape.num_elements() * sizeof(T),
input.shape.num_elements() * size_of_datatype(T),
hipMemcpyDeviceToDevice,
stream));
}
Expand All @@ -42,34 +43,34 @@ struct ForwardKernel {
template <DataType T>
struct BackwardKernel {
void operator()(hipStream_t stream,
ReshapePerDeviceState const *m,
ReshapePerDeviceState const &m,
GenericTensorAccessorW const &input,
GenericTensorAccessorR const &output) {
float alpha = 1.0f;
hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale<T>),
hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale<real_type<T>>),
GET_BLOCKS(input.shape.num_elements()),
CUDA_NUM_THREADS,
0,
stream,
input.get<T>(),
output.get<T>(),
input.shape.num_elements(),
(T)alpha);
static_cast<real_type<T>> alpha);
}
}

void forward_kernel(hipStream_t stream,
ReshapePerDeviceState const *m,
ReshapePerDeviceState const &m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
DataTypeDispatch1<ForwardKernel>{}(m->data_type, stream, m, input, output);
DataTypeDispatch1<ForwardKernel>{}(m.data_type, stream, m, input, output);
}

void backward_kernel(hipStream_t stream,
ReshapePerDeviceState const *m,
ReshapePerDeviceState const &m,
GenericTensorAccessorW const &input,
GenericTensorAccessorR const &output) {
DataTypeDispatch1<BackwardKernel>{}(m->data_type, stream, m, input, output);
DataTypeDispatch1<BackwardKernel>{}(m.data_type, stream, m, input, output);
}

} // namespace Reshape
Expand Down
36 changes: 17 additions & 19 deletions lib/kernels/src/hip/ops/reverse_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,30 @@
*/

#include "kernels/reverse_kernels.h"
#include "kernels/hip_helper.h"
#include "device.h"
#include <hip/hip_runtime.h>

namespace FlexFlow {
// declare Legion names
using Legion::coord_t;

namespace Kernels {
namespace Reverse {

// Reverses the input along one dimension. The flat index space is
// (num_out_blks, reverse_dim_size, in_blk_size), outer-to-inner; element
// (b, r, k) of the output is element (b, reverse_dim_size-1-r, k) of the
// input. Grid-stride loop via CUDA_KERNEL_LOOP.
__global__ void reverse_forward_kernel(float const *in_ptr,
                                       float *out_ptr,
                                       coord_t num_out_blks,
                                       coord_t reverse_dim_size,
                                       coord_t in_blk_size) {
  CUDA_KERNEL_LOOP(i, num_out_blks * reverse_dim_size * in_blk_size) {
    // Decompose into fresh locals WITHOUT mutating `i`: clobbering the
    // induction variable made the old code write to the inner offset
    // instead of the global output index, and corrupted the grid-stride
    // advance performed by CUDA_KERNEL_LOOP.
    coord_t blk_idx = i / (reverse_dim_size * in_blk_size);
    coord_t rem = i - blk_idx * (reverse_dim_size * in_blk_size);
    coord_t reverse_dim_idx = rem / in_blk_size;
    coord_t inner_idx = rem - reverse_dim_idx * in_blk_size;
    coord_t in_idx =
        blk_idx * (reverse_dim_size * in_blk_size) +
        (reverse_dim_size - 1 - reverse_dim_idx) * in_blk_size + inner_idx;
    out_ptr[i] = in_ptr[in_idx];
  }
}

void forward_kernel(hipStream_t stream,
float const *in_ptr,
float *out_ptr,
Expand Down Expand Up @@ -64,22 +78,6 @@ void backward_kernel(hipStream_t stream,
in_blk_size);
}

// Reverses the input along one dimension. The flat index space is
// (num_out_blks, reverse_dim_size, in_blk_size), outer-to-inner; element
// (b, r, k) of the output is element (b, reverse_dim_size-1-r, k) of the
// input. Grid-stride loop via CUDA_KERNEL_LOOP.
__global__ void reverse_forward_kernel(float const *in_ptr,
                                       float *out_ptr,
                                       coord_t num_out_blks,
                                       coord_t reverse_dim_size,
                                       coord_t in_blk_size) {
  CUDA_KERNEL_LOOP(i, num_out_blks * reverse_dim_size * in_blk_size) {
    // Decompose into fresh locals WITHOUT mutating `i`: clobbering the
    // induction variable made the old code write to the inner offset
    // instead of the global output index, and corrupted the grid-stride
    // advance performed by CUDA_KERNEL_LOOP.
    coord_t blk_idx = i / (reverse_dim_size * in_blk_size);
    coord_t rem = i - blk_idx * (reverse_dim_size * in_blk_size);
    coord_t reverse_dim_idx = rem / in_blk_size;
    coord_t inner_idx = rem - reverse_dim_idx * in_blk_size;
    coord_t in_idx =
        blk_idx * (reverse_dim_size * in_blk_size) +
        (reverse_dim_size - 1 - reverse_dim_idx) * in_blk_size + inner_idx;
    out_ptr[i] = in_ptr[in_idx];
  }
}

} // namespace Reverse
} // namespace Kernels
} // namespace FlexFlow