Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
6b248cf
batch matmul initial commit
KateUnger Aug 22, 2023
d9aae74
fix FF_VISITABLE_STRUCT_NO_EQ
KateUnger Aug 22, 2023
45ecd8c
Merge branch 'repo-refactor' into batch_matmul
lockshaw Aug 23, 2023
2073822
finish draft 1 batch_matmul
KateUnger Aug 23, 2023
5c322bd
Merge branch 'batch_matmul' of github.com:KateUnger/FlexFlow into bat…
KateUnger Aug 23, 2023
d5806a0
add output and weights
KateUnger Aug 23, 2023
f620070
format
KateUnger Aug 23, 2023
e954ab0
Merge branch 'repo-refactor' of https://github.com/flexflow/FlexFlow …
KateUnger Aug 23, 2023
5458923
fix DeviceSpecific
KateUnger Aug 23, 2023
2f604dc
batch_norm
KateUnger Aug 26, 2023
a32f4e5
cast op
KateUnger Aug 28, 2023
d02998d
combine
KateUnger Aug 28, 2023
f2205f4
change
KateUnger Aug 29, 2023
2f4662d
change
KateUnger Aug 29, 2023
642eb90
change
KateUnger Aug 29, 2023
e79406a
change
KateUnger Aug 29, 2023
bb3c10f
change
KateUnger Aug 29, 2023
8488509
change
KateUnger Aug 29, 2023
36eba29
change
KateUnger Aug 29, 2023
98773b4
change
KateUnger Aug 29, 2023
ae59261
change
KateUnger Aug 29, 2023
fe049ef
Merge branch 'batch_matmul' into concat
KateUnger Aug 29, 2023
8db2512
fix asserts
KateUnger Aug 30, 2023
6a28b94
Merge branch 'batch_matmul' into batch_norm
KateUnger Aug 30, 2023
e8a6c30
remove asserts
KateUnger Aug 30, 2023
09acbe5
format
KateUnger Aug 30, 2023
779d8f0
Merge branch 'batch_matmul' into batch_norm
KateUnger Aug 30, 2023
cddefe6
Merge branch 'batch_norm' into cast
KateUnger Aug 30, 2023
c5338ea
Merge branch 'cast' into combine
KateUnger Aug 30, 2023
8c603df
Merge branch 'combine' into concat
KateUnger Aug 30, 2023
67d2d1e
concat
KateUnger Aug 31, 2023
d3342d5
conv2d
KateUnger Sep 1, 2023
bf75067
format
KateUnger Sep 1, 2023
3631b0e
Merge branch 'combine' into concat
KateUnger Sep 1, 2023
d1f4fb9
format
KateUnger Sep 1, 2023
b3bbe2e
Merge branch 'concat' into conv2d
KateUnger Sep 1, 2023
fdf8351
format
KateUnger Sep 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions lib/kernels/include/kernels/batch_matmul_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,43 @@

namespace FlexFlow {

class BatchMatmulPerDeviceState : public PerDeviceOpState {
public:
BatchMatmulPerDeviceState(FFHandler handler);
int a_seq_length_dim, b_seq_length_dim;
// Per-device state for the batch-matmul kernels, declared as a plain
// aggregate. Its members are listed, in declaration order, in the
// FF_VISITABLE_STRUCT_NO_EQ invocation below.
//
// NOTE(review): the forward_kernel definitions in
// src/cuda/batch_matmul_kernels.cu and src/hip/batch_matmul_kernels.cpp
// touched by this same change take `BatchMatmulPerDeviceState const &`,
// not `BMMPerDeviceState const &` — confirm the type name is reconciled.
struct BMMPerDeviceState {
// Device handle; type is declared elsewhere.
PerDeviceFFHandle handle;
// Allocator carried with the state; how the kernels use it is not visible
// in this header.
Allocator allocator;
// presumably the index of the sequence-length dimension of the A operand
// — TODO confirm against the kernel implementation.
int a_seq_length_dim;
// Counterpart of a_seq_length_dim for the B operand. The final member is
// wrapped in req<>, as in the other per-device state structs in this
// change; the wrapper itself is defined elsewhere.
req<int> b_seq_length_dim;
};

// Keep this field list in sync with the struct members above.
FF_VISITABLE_STRUCT_NO_EQ(
BMMPerDeviceState, handle, allocator, a_seq_length_dim, b_seq_length_dim);

namespace Kernels {
namespace BatchMatmul {

// Produces the BMMPerDeviceState that forward_kernel takes as `meta`.
// The four parameters match the four members of BMMPerDeviceState by name;
// presumably they are stored unchanged (definition not visible here —
// confirm).
BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle,
Allocator const &allocator,
int a_seq_length_dim,
int b_seq_length_dim);

void forward_kernel(ffStream_t stream,
BatchMatmulPerDeviceState const *,
float *o_ptr,
float const *a_ptr,
float const *b_ptr,
float const *c_ptr,
BMMPerDeviceState const &meta,
float *output_ptr,
float const *lhs_input_ptr,
float const *rhs_input_ptr,
int m,
int n,
int k,
int batch,
int a_seq_length_dim = -1,
int b_seq_length_dim = -1,
int seq_length = -1);

void backward_kernel(ffStream_t stream,
BatchMatmulPerDeviceState const *,
BMMPerDeviceState const &meta,
float const *o_ptr,
float const *o_grad_ptr,
float const *a_ptr,
float *a_grad_ptr,
float const *b_ptr,
float *b_grad_ptr,
float *c_grad_ptr,
int m,
int n,
int k,
Expand Down
70 changes: 53 additions & 17 deletions lib/kernels/include/kernels/batch_norm_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,66 @@

namespace FlexFlow {

class BatchNormPerDeviceState : public PerDeviceOpState {
public:
BatchNormPerDeviceState(FFHandler handle,
std::unique_ptr<IAllocator> allocator,
int output_n,
int output_c,
int output_h,
int output_w,
bool relu,
bool profiling);
~BatchNormPerDeviceState(void);

ffTensorDescriptor_t inputTensor, outputTensor, biasTensor;
struct BatchNormPerDeviceState {
PerDeviceFFHandle handle;
Allocator allocator;
ffTensorDescriptor_t inputTensor;
ffTensorDescriptor_t outputTensor;
ffTensorDescriptor_t biasTensor;
ffActivationDescriptor_t actiDesc;
ffBatchNormMode_t mode;
float *runningMean, *runningVar, *saveMean, *saveVar;
bool relu;
bool profiling;
std::unique_ptr<IAllocator> allocator;
float *runningMean;
float *runningVar;
float *saveMean;
float *saveVar;
int output_n;
int output_c;
int output_h;
int output_w;
ProfilingSettings profiling;
req<bool> relu;
};

// Visitable-struct registration for BatchNormPerDeviceState. The field list
// must name every member of the struct; keep it in sync when members are
// added or removed. (Macro is defined elsewhere; the _NO_EQ suffix
// presumably means no equality operator is generated — confirm.)
FF_VISITABLE_STRUCT_NO_EQ(BatchNormPerDeviceState,
handle,
allocator,
inputTensor,
outputTensor,
biasTensor,
actiDesc,
mode,
runningMean,
runningVar,
saveMean,
saveVar,
output_n,
output_c,
output_h,
output_w,
profiling,
relu);

namespace Kernels {
namespace BatchNorm {

// Produces a BatchNormPerDeviceState. Every parameter matches, by name and
// order, a field registered for BatchNormPerDeviceState in the
// FF_VISITABLE_STRUCT_NO_EQ list above; presumably each is stored as given
// (definition not visible here — confirm).
//
// NOTE(review): `handle` and `allocator` are taken by value here, whereas
// BatchMatmul's init_kernel takes them by const reference — confirm which
// convention is intended.
// NOTE(review): ownership of the descriptors and of the four float buffers
// (runningMean/runningVar/saveMean/saveVar) is not expressed in this
// signature — confirm who creates and releases them.
BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle,
Allocator allocator,
ffTensorDescriptor_t inputTensor,
ffTensorDescriptor_t outputTensor,
ffTensorDescriptor_t biasTensor,
ffActivationDescriptor_t actiDesc,
ffBatchNormMode_t mode,
float *runningMean,
float *runningVar,
float *saveMean,
float *saveVar,
int output_n,
int output_c,
int output_h,
int output_w,
ProfilingSettings profiling,
bool relu);

void forward_kernel(ffStream_t stream,
BatchNormPerDeviceState *m,
float const *input_ptr,
Expand Down
17 changes: 12 additions & 5 deletions lib/kernels/include/kernels/cast_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,26 @@

#include "kernels/accessor.h"
#include "kernels/device.h"
#include "op-attrs/ffconst.h"

namespace FlexFlow {

class CastPerDeviceState : public PerDeviceOpState {
public:
CastPerDeviceState(FFHandler handle);
DataType input_data_type, output_data_type;
// Per-device state for the cast kernels: a device handle plus the source
// and destination element types.
struct CastPerDeviceState {
// Device handle; type is declared elsewhere.
PerDeviceFFHandle handle;
// Element type of the tensor being read.
DataType input_data_type;
// Element type of the tensor being written. Wrapped in req<> as the final
// member, matching the other per-device state structs in this change.
req<DataType> output_data_type;
};

// Keep this field list in sync with the struct members above.
FF_VISITABLE_STRUCT_NO_EQ(CastPerDeviceState,
handle,
input_data_type,
output_data_type);

namespace Kernels {
namespace Cast {

// Produces a CastPerDeviceState from a device handle and the input/output
// data types. Presumably `input` and `output` populate input_data_type and
// output_data_type respectively (definition not visible here — confirm).
CastPerDeviceState
init_kernel(PerDeviceFFHandle const &, DataType input, DataType output);

void forward_kernel(ffStream_t stream,
CastPerDeviceState const *,
GenericTensorAccessorR const &input,
Expand Down
10 changes: 6 additions & 4 deletions lib/kernels/include/kernels/combine_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@

namespace FlexFlow {

class CombinePerDeviceState : public PerDeviceOpState {
public:
CombinePerDeviceState(FFHandler handle);
DataType data_type;
// Per-device state for the combine kernels. Unlike the other per-device
// state structs in this change it carries no device handle, only the
// element type.
struct CombinePerDeviceState {
// Element type the combine kernels operate on.
req<DataType> data_type;
};

// Keep this field list in sync with the struct member above.
FF_VISITABLE_STRUCT_NO_EQ(CombinePerDeviceState, data_type);

namespace Kernels {
namespace Combine {

// Produces a CombinePerDeviceState for the given element type.
CombinePerDeviceState init_kernel(DataType data_type);

void forward_kernel(ffStream_t stream,
CombinePerDeviceState const *m,
GenericTensorAccessorR const &input,
Expand Down
24 changes: 12 additions & 12 deletions lib/kernels/include/kernels/concat_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,29 @@

namespace FlexFlow {

class ConcatPerDeviceState : public PerDeviceOpState {
public:
ConcatPerDeviceState(FFHandler handle) : PerDeviceOpState(handle){};
int legion_axis;
char op_name[MAX_OPNAME];
// Per-device state for the concat kernels: only the axis is recorded.
struct ConcatPerDeviceState {
// Axis along which inputs are concatenated, as an ff_dim_t. The name
// suggests Legion dimension ordering — TODO confirm.
req<ff_dim_t> legion_axis;
};

// NOTE(review): this uses FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION
// while the sibling per-device state structs in this change use
// FF_VISITABLE_STRUCT_NO_EQ — confirm the difference is intentional.
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ConcatPerDeviceState, legion_axis);

namespace Kernels {
namespace Concat {

void init_meta(ConcatPerDeviceState *meta, int legion_axis);
// Produces a ConcatPerDeviceState for the given concatenation axis.
ConcatPerDeviceState init_kernel(ff_dim_t legion_axis);

void forward_kernel(ffStream_t stream,
ConcatPerDeviceState const *m,
GenericTensorAccessorW const &output,
GenericTensorAccessorR const *inputs,
std::vector<FlexFlow::GenericTensorAccessorR> const &inputs,
int num_inputs);

void backward_kernel(ffStream_t stream,
ConcatPerDeviceState const *m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const *input_grads,
int num_inputs);
// Backward pass of concat on `stream`: reads `output_grad` and writes into
// the writable accessors in `input_grads`.
// NOTE(review): `num_inputs` is still passed although `input_grads` is now
// a std::vector that carries its own size — confirm the two are expected to
// agree, or whether the parameter can be dropped.
void backward_kernel(
ffStream_t stream,
ConcatPerDeviceState const *m,
GenericTensorAccessorR const &output_grad,
std::vector<FlexFlow::GenericTensorAccessorW> const &input_grads,
int num_inputs);

} // namespace Concat
} // namespace Kernels
Expand Down
65 changes: 35 additions & 30 deletions lib/kernels/include/kernels/conv_2d_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,50 @@

namespace FlexFlow {

class Conv2DPerDeviceState : public PerDeviceOpState {
public:
Conv2DPerDeviceState(FFHandler handler);
ffTensorDescriptor_t inputTensor, biasTensor, outputTensor;
struct Conv2DPerDeviceState {
PerDeviceFFHandle handle;
ffTensorDescriptor_t inputTensor;
ffTensorDescriptor_t biasTensor;
ffTensorDescriptor_t outputTensor;
ffFilterDescriptor_t filterDesc;
ffActivationDescriptor_t actiDesc;
ffConvolutionDescriptor_t convDesc;
ffConvolutionFwdAlgo_t fwdAlgo;
ffConvolutionBwdFilterAlgo_t bwdFilterAlgo;
ffConvolutionBwdDataAlgo_t bwdDataAlgo;
bool relu, use_bias;
char op_name[MAX_OPNAME];
req<optional<Activation>> activation;
req<bool> use_bias;
};

// Visitable-struct registration for Conv2DPerDeviceState. The field list
// must name every member of the struct; keep it in sync when members are
// added or removed. (Macro is defined elsewhere; the _NO_EQ suffix
// presumably means no equality operator is generated — confirm.)
FF_VISITABLE_STRUCT_NO_EQ(Conv2DPerDeviceState,
handle,
inputTensor,
biasTensor,
outputTensor,
filterDesc,
actiDesc,
convDesc,
fwdAlgo,
bwdFilterAlgo,
bwdDataAlgo,
activation,
use_bias);

namespace Kernels {
namespace Conv2D {

void init_kernel(Conv2DPerDeviceState *m,
int input_w,
int input_h,
int input_c,
int input_n,
int output_w,
int output_h,
int output_c,
int output_n,
int kernel_h,
int kernel_w,
int groups,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
float const *input_ptr,
float *output_ptr,
float const *kernel_ptr,
float *kernel_grad_ptr,
float *forward_time = nullptr,
float *backward_time = nullptr);
// Produces a Conv2DPerDeviceState. Each parameter matches, by name and
// order, a field registered for Conv2DPerDeviceState in the
// FF_VISITABLE_STRUCT_NO_EQ list above; presumably each is stored as given
// (definition not visible here — confirm).
//
// The optional-activation parameter is named `activation` to agree with its
// type and with the struct member of the same name; it was previously
// called `relu`, which no longer describes an optional<Activation>.
//
// NOTE(review): `activation` is declared as req<optional<Activation>>;
// elsewhere in this change req<> appears only on struct members — confirm
// whether a plain optional<Activation> parameter was intended.
// NOTE(review): ownership of the descriptors is not expressed in this
// signature — confirm who creates and destroys them.
Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle,
ffTensorDescriptor_t inputTensor,
ffTensorDescriptor_t biasTensor,
ffTensorDescriptor_t outputTensor,
ffFilterDescriptor_t filterDesc,
ffActivationDescriptor_t actiDesc,
ffConvolutionDescriptor_t convDesc,
ffConvolutionFwdAlgo_t fwdAlgo,
ffConvolutionBwdFilterAlgo_t bwdFilterAlgo,
ffConvolutionBwdDataAlgo_t bwdDataAlgo,
req<optional<Activation>> activation,
bool use_bias);

void forward_kernel(ffStream_t stream,
Conv2DPerDeviceState const *m,
Expand All @@ -58,8 +63,8 @@ void backward_kernel(ffStream_t stream,
float *input_grad_ptr,
float const *output_ptr,
float *output_grad_ptr,
float const *kernel_ptr,
float *kernel_grad_ptr,
float const *filter_ptr,
float *filter_grad_ptr,
float *bias_grad_ptr);

} // namespace Conv2D
Expand Down
5 changes: 1 addition & 4 deletions lib/kernels/src/cuda/batch_matmul_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@

namespace FlexFlow {

BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler)
: PerDeviceOpState(handler) {}

namespace Kernels {
namespace BatchMatmul {

Expand Down Expand Up @@ -124,7 +121,7 @@ O = A * B
*/

void forward_kernel(cudaStream_t stream,
BatchMatmulPerDeviceState const *meta,
BatchMatmulPerDeviceState const &meta,
float *o_ptr,
float const *a_ptr,
float const *b_ptr,
Expand Down
9 changes: 3 additions & 6 deletions lib/kernels/src/hip/batch_matmul_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@

namespace FlexFlow {

BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler)
: PerDeviceOpState(handler) {}

namespace Kernels {
namespace BatchMatmul {

Expand All @@ -32,7 +29,7 @@ O: (batch, n, m)
O = A * B
*/
void forward_kernel(hipStream_t stream,
BatchMatmulPerDeviceState const *meta,
BatchMatmulPerDeviceState const &meta,
float *o_ptr,
float const *a_ptr,
float const *b_ptr,
Expand All @@ -42,9 +39,9 @@ void forward_kernel(hipStream_t stream,
int k,
int batch,
hipStream_t stream,
int a_seq_length_dim,
int b_seq_length_dim,
int seq_length) {
int a_seq_length_dim = meta->a_seq_length_dim;
int b_seq_length_dim = meta->b_seq_length_dim;
checkCUDA(hipblasSetStream(meta->handle.blas, stream));
checkCUDNN(miopenSetStream(meta->handle.dnn, stream));

Expand Down
4 changes: 0 additions & 4 deletions lib/kernels/src/hip/concat_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ using Legion::Rect;
namespace Kernels {
namespace Concat {

void init_meta(ConcatPerDeviceState *m, int legion_axis) {
m->legion_axis = legion_axis;
}

template <int N>
void calc_blk_size(coord_t &num_blocks,
coord_t &blk_size,
Expand Down
1 change: 1 addition & 0 deletions lib/op-attrs/include/op-attrs/get_op_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ OperatorType get_op_type(BatchMatmulAttrs const &);
OperatorType get_op_type(BatchNormAttrs const &);
OperatorType get_op_type(BroadcastAttrs const &);
OperatorType get_op_type(CastAttrs const &);
OperatorType get_op_type(CombineAttrs const &);
OperatorType get_op_type(ConcatAttrs const &);
OperatorType get_op_type(Conv2DAttrs const &);
OperatorType get_op_type(DropoutAttrs const &);
Expand Down
Loading