Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions lib/kernels/include/kernels/batch_matmul_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,43 @@

namespace FlexFlow {

class BatchMatmulPerDeviceState : public PerDeviceOpState {
public:
BatchMatmulPerDeviceState(FFHandler handler);
int a_seq_length_dim, b_seq_length_dim;
struct BMMPerDeviceState {
PerDeviceFFHandle handle;
Allocator allocator;
int a_seq_length_dim;
req<int> b_seq_length_dim;
};

FF_VISITABLE_STRUCT_NO_EQ(
BMMPerDeviceState, handle, allocator, a_seq_length_dim, b_seq_length_dim);

namespace Kernels {
namespace BatchMatmul {

BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle,
Allocator const &allocator,
int a_seq_length_dim,
int b_seq_length_dim);

void forward_kernel(ffStream_t stream,
BatchMatmulPerDeviceState const *,
float *o_ptr,
float const *a_ptr,
float const *b_ptr,
float const *c_ptr,
BMMPerDeviceState const &meta,
float *output_ptr,
float const *lhs_input_ptr,
float const *rhs_input_ptr,
int m,
int n,
int k,
int batch,
int a_seq_length_dim = -1,
int b_seq_length_dim = -1,
int seq_length = -1);

void backward_kernel(ffStream_t stream,
BatchMatmulPerDeviceState const *,
BMMPerDeviceState const &meta,
float const *o_ptr,
float const *o_grad_ptr,
float const *a_ptr,
float *a_grad_ptr,
float const *b_ptr,
float *b_grad_ptr,
float *c_grad_ptr,
int m,
int n,
int k,
Expand Down
70 changes: 53 additions & 17 deletions lib/kernels/include/kernels/batch_norm_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,66 @@

namespace FlexFlow {

class BatchNormPerDeviceState : public PerDeviceOpState {
public:
BatchNormPerDeviceState(FFHandler handle,
std::unique_ptr<IAllocator> allocator,
int output_n,
int output_c,
int output_h,
int output_w,
bool relu,
bool profiling);
~BatchNormPerDeviceState(void);

ffTensorDescriptor_t inputTensor, outputTensor, biasTensor;
struct BatchNormPerDeviceState {
PerDeviceFFHandle handle;
Allocator allocator;
ffTensorDescriptor_t inputTensor;
ffTensorDescriptor_t outputTensor;
ffTensorDescriptor_t biasTensor;
ffActivationDescriptor_t actiDesc;
ffBatchNormMode_t mode;
float *runningMean, *runningVar, *saveMean, *saveVar;
bool relu;
bool profiling;
std::unique_ptr<IAllocator> allocator;
float *runningMean;
float *runningVar;
float *saveMean;
float *saveVar;
int output_n;
int output_c;
int output_h;
int output_w;
ProfilingSettings profiling;
req<bool> relu;
};

FF_VISITABLE_STRUCT_NO_EQ(BatchNormPerDeviceState,
handle,
allocator,
inputTensor,
outputTensor,
biasTensor,
actiDesc,
mode,
runningMean,
runningVar,
saveMean,
saveVar,
output_n,
output_c,
output_h,
output_w,
profiling,
relu);

namespace Kernels {
namespace BatchNorm {

BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle,
Allocator allocator,
ffTensorDescriptor_t inputTensor,
ffTensorDescriptor_t outputTensor,
ffTensorDescriptor_t biasTensor,
ffActivationDescriptor_t actiDesc,
ffBatchNormMode_t mode,
float *runningMean,
float *runningVar,
float *saveMean,
float *saveVar,
int output_n,
int output_c,
int output_h,
int output_w,
ProfilingSettings profiling,
bool relu);

void forward_kernel(ffStream_t stream,
BatchNormPerDeviceState *m,
float const *input_ptr,
Expand Down
5 changes: 1 addition & 4 deletions lib/kernels/src/cuda/batch_matmul_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@

namespace FlexFlow {

BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler)
: PerDeviceOpState(handler) {}

namespace Kernels {
namespace BatchMatmul {

Expand Down Expand Up @@ -124,7 +121,7 @@ O = A * B
*/

void forward_kernel(cudaStream_t stream,
BatchMatmulPerDeviceState const *meta,
BatchMatmulPerDeviceState const &meta,
float *o_ptr,
float const *a_ptr,
float const *b_ptr,
Expand Down
9 changes: 3 additions & 6 deletions lib/kernels/src/hip/batch_matmul_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@

namespace FlexFlow {

BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler)
: PerDeviceOpState(handler) {}

namespace Kernels {
namespace BatchMatmul {

Expand All @@ -32,7 +29,7 @@ O: (batch, n, m)
O = A * B
*/
void forward_kernel(hipStream_t stream,
BatchMatmulPerDeviceState const *meta,
BatchMatmulPerDeviceState const &meta,
float *o_ptr,
float const *a_ptr,
float const *b_ptr,
Expand All @@ -42,9 +39,9 @@ void forward_kernel(hipStream_t stream,
int k,
int batch,
hipStream_t stream,
int a_seq_length_dim,
int b_seq_length_dim,
int seq_length) {
int a_seq_length_dim = meta->a_seq_length_dim;
int b_seq_length_dim = meta->b_seq_length_dim;
checkCUDA(hipblasSetStream(meta->handle.blas, stream));
checkCUDNN(miopenSetStream(meta->handle.dnn, stream));

Expand Down
4 changes: 3 additions & 1 deletion lib/op-attrs/include/op-attrs/ops/batch_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ struct BatchMatmulAttrs {
};
FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim);

CHECK_VALID_OP_ATTR(BatchMatmulAttrs);
int get_aSeqLengthDim(BatchMatmulAttrs const &attrs);
int get_bSeqLengthDim(BatchMatmulAttrs const &attrs);

CHECK_VALID_OP_ATTR(BatchMatmulAttrs);
} // namespace FlexFlow

#endif
8 changes: 8 additions & 0 deletions lib/op-attrs/src/batch_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

namespace FlexFlow {

int get_aSeqLengthDim(BatchMatmulAttrs const &attrs) {
return attrs.a_seq_length_dim;
}

int get_bSeqLengthDim(BatchMatmulAttrs const &attrs) {
return attrs.b_seq_length_dim;
}

/* bool BatchMatmulAttrs::is_valid( */
/* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const {
*/
Expand Down
5 changes: 3 additions & 2 deletions lib/runtime/include/runtime/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,14 @@ struct FFConfig : public use_visitable_cmp<FFConfig> {
int python_data_loader_type = 2;
};

class FFIterationConfig {
public:
struct FFIterationConfig {
FFIterationConfig();
void reset();
int seq_length;
};

FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length);

enum FieldIDs {
FID_DATA,
};
Expand Down
15 changes: 0 additions & 15 deletions lib/runtime/src/ops/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,6 @@ static DeviceSpecific<MHAPerDeviceState>
int num_samples = get_piece_shape(query_parallel_tensor_shape)[ff_dim_t(2)];
int num_heads = get_piece_shape(weight_parallel_tensor_shape)[ff_dim_t(1)];

assert(qoSeqLength == query.shape[legion_dim_t(1)]);
assert(qSize == query.shape[legion_dim_t(0)]);
assert(num_samples == key.shape[legion_dim_t(2)]);
assert(kvSeqLength == key.shape[legion_dim_t(1)]);
assert(kSize == key.shape[legion_dim_t(0)]);
assert(num_samples == value.shape[legion_dim_t(2)]);
assert(kvSeqLength == value.shape[legion_dim_t(1)]);
assert(vSize == value.shape[legion_dim_t(0)]);
assert(num_samples == output.shape[legion_dim_t(2)]);
assert(qoSeqLength == output.shape[legion_dim_t(1)]);
assert(oProjSize == output.shape[legion_dim_t(0)]);

DeviceSpecific<MHAPerDeviceState> per_device_state =
acc.create_device_specific<MHAPerDeviceState>(
init_kernel(handle,
Expand All @@ -149,9 +137,6 @@ static DeviceSpecific<MHAPerDeviceState>
qoSeqLength,
kvSeqLength,
attrs.add_bias_kv));

assert(weight.shape.get_volume() * sizeof(float) ==
acc.unwrap(per_device_state)->weightSize);
return per_device_state;
}

Expand Down
Loading