Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c5e3202
add init_kernel for reduction
lambda7xx Sep 8, 2023
2ce00f0
implement the init
lambda7xx Sep 8, 2023
356cdde
add forward API
lambda7xx Sep 8, 2023
c45a318
add backward API
lambda7xx Sep 8, 2023
074f1fe
implement measure_operator_cost and Reduce version 0.1
lambda7xx Sep 8, 2023
db204dd
format the code
lambda7xx Sep 8, 2023
bf95bc4
update the reduce
lambda7xx Sep 24, 2023
be032a1
update the reduce
lambda7xx Sep 24, 2023
bb542fd
use exceptions in reduces.cc
lambda7xx Sep 27, 2023
ae22a95
fix the typo
lambda7xx Sep 30, 2023
40c6fc4
Merge branch 'repo-refactor' into repo-refactor-lambda-Reduce-OP
lockshaw Oct 6, 2023
a1fab29
Merge branch 'repo-refactor' into repo-refactor-lambda-Reduce-OP
lockshaw Oct 6, 2023
0573b88
fix the reduce.cc
lambda7xx Oct 10, 2023
dba438d
format the code
lambda7xx Oct 10, 2023
fe76016
Merge branch 'repo-refactor-lambda-Reduce-OP' of https://github.com/l…
lambda7xx Oct 10, 2023
3e3c147
Merge branch 'repo-refactor' into repo-refactor-lambda-Reduce-OP
reyna-abhyankar Oct 17, 2023
a74ae7a
Merge branch 'repo-refactor' into repo-refactor-lambda-Reduce-OP
reyna-abhyankar Oct 26, 2023
925356e
add reduction size
lambda7xx Oct 29, 2023
7ffc098
Merge branch 'repo-refactor-lambda-Reduce-OP' of https://github.com/l…
lambda7xx Oct 29, 2023
f1594ae
Merge branch 'repo-refactor' into repo-refactor-lambda-Reduce-OP
reyna-abhyankar Nov 1, 2023
408cf3e
Merge branch 'repo-refactor' into repo-refactor-lambda-Reduce-OP
reyna-abhyankar Nov 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions lib/kernels/include/kernels/reduce_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,37 @@

namespace FlexFlow {

class ReducePerDeviceState : public PerDeviceOpState {
public:
ReducePerDeviceState(FFHandler handler,
Reduce const *rd,
Legion::Domain const &input_domain);
~ReducePerDeviceState(void);
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
cudnnTensorDescriptor_t inputTensor, outputTensor;
cudnnReduceTensorDescriptor_t reduceDesc;
#else
miopenTensorDescriptor_t inputTensor, outputTensor;
miopenReduceTensorDescriptor_t reduceDesc;
#endif
struct ReducePerDeviceState {
PerDeviceFFHandle handle;
ffTensorDescriptor_t inputTensor;
ffTensorDescriptor_t outputTensor;
ffReduceTensorDescriptor_t reduceDesc;
OperatorType op_type;
size_t reduction_size;
};

FF_VISITABLE_STRUCT(ReducePerDeviceState,
handle,
inputTensor,
outputTensor,
reduceDesc,
op_type,
reduction_size);

namespace Kernels {
namespace Reduce {
void forward_kernel_wrapper(ReducePerDeviceState const *m,

ReducePerDeviceState init_kernel(PerDeviceFFhandle const &,
OperatorType const &,
size_t const &,
ArrayShape input_shape,
ArrayShape output_shape);

void forward_kernel_wrapper(ReducePerDeviceState const &m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);

void backward_kernel_wrapper(ReducePerDeviceState const *m,
void backward_kernel_wrapper(ReducePerDeviceState const &m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const &input_grad);

Expand Down
53 changes: 38 additions & 15 deletions lib/kernels/src/cuda/reduce_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,50 +67,73 @@ ReducePerDeviceState::~ReducePerDeviceState(void) {
namespace Kernels {
namespace Reduce {

/**
 * Builds the per-device state for the Reduce operator.
 *
 * Creates cuDNN tensor descriptors for the input and output shapes and a
 * reduce-tensor descriptor, then packages them with the handle, op type
 * and reduction size.
 *
 * BUG FIX: the original constructed `per_device` but fell off the end of a
 * non-void function without returning it (undefined behavior); the result
 * is now returned. Also fixes the `PerDeviceFFhandle` casing typo (the
 * struct field is declared as PerDeviceFFHandle).
 *
 * NOTE(review): reduceDesc is created but never configured here (no
 * cudnnSetReduceTensorDescriptor call mapping op_type to a
 * cudnnReduceTensorOp_t) — confirm whether configuration happens elsewhere
 * before this descriptor is used by cudnnReduceTensor.
 */
ReducePerDeviceState init_kernel(PerDeviceFFHandle const &handle,
                                 OperatorType const &op_type,
                                 size_t const &reduction_size,
                                 ArrayShape const &input_shape,
                                 ArrayShape const &output_shape) {
  ffTensorDescriptor_t inputTensor;
  ffTensorDescriptor_t outputTensor;
  ffReduceTensorDescriptor_t reduceDesc;

  checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor));
  checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor));
  checkCUDNN(cudnnCreateReduceTensorDescriptor(&reduceDesc));

  checkCUDNN(cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape));
  checkCUDNN(
      cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape));

  ReducePerDeviceState per_device = {
      handle, inputTensor, outputTensor, reduceDesc, op_type, reduction_size};
  return per_device;
}

/**
 * Forward pass: output = reduce(input) per m.reduceDesc.
 *
 * Binds the given stream to the cuDNN handle, then calls cudnnReduceTensor
 * with alpha=1, beta=0 (output is overwritten, not accumulated). Indices
 * are not requested; the handle's workspace is used as scratch.
 *
 * FIX: reconstructed from diff residue — the pasted source interleaved the
 * old pointer-based (`m->`) lines with the new reference-based (`m.`)
 * lines; this is the clean reference-taking version.
 */
void forward_kernel(cudaStream_t stream,
                    ReducePerDeviceState const &m,
                    float const *input_ptr,
                    float *output_ptr) {
  checkCUDNN(cudnnSetStream(m.handle.dnn, stream));
  float alpha = 1.0f, beta = 0.0f;
  checkCUDNN(cudnnReduceTensor(m.handle.dnn,
                               m.reduceDesc,
                               nullptr /*indices*/,
                               0 /*indicesSizeInBytes*/,
                               m.handle.workSpace,
                               m.handle.workSpaceSize,
                               &alpha,
                               m.inputTensor,
                               input_ptr,
                               &beta,
                               m.outputTensor,
                               output_ptr));
}

/**
 * Backward pass: input_grad += alpha * output_grad (broadcast through the
 * tensor descriptors by cudnnAddTensor; beta=1 accumulates into the
 * existing input gradient).
 *
 * alpha is 1 for OP_REDUCE_SUM; for OP_REDUCE_MEAN each output element is
 * the average of reduction_size inputs, so the gradient is scaled by
 * 1/reduction_size. Any other op_type is a programming error (assert).
 *
 * FIX: reconstructed from diff residue (interleaved `m->`/`m.` lines);
 * also normalized the `1.0` double literal to `1.0f` to match the file's
 * float-literal convention.
 */
void backward_kernel(cudaStream_t stream,
                     ReducePerDeviceState const &m,
                     float const *output_grad_ptr,
                     float *input_grad_ptr) {
  checkCUDNN(cudnnSetStream(m.handle.dnn, stream));
  float alpha = 1.0f, beta = 1.0f;
  switch (m.op_type) {
    case OP_REDUCE_SUM:
      alpha = 1.0f;
      break;
    case OP_REDUCE_MEAN:
      // When the output is the average of multiple input elements
      // we need to scale the gradients by 1.0 / reduction_size
      alpha = 1.0f / m.reduction_size;
      break;
    default:
      assert(false);
  }
  checkCUDNN(cudnnAddTensor(m.handle.dnn,
                            &alpha,
                            m.outputTensor,
                            output_grad_ptr,
                            &beta,
                            m.inputTensor,
                            input_grad_ptr));
}

Expand Down
Loading