From 6b248cf20e560dbc2ff57d0db2bf463775177620 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 22 Aug 2023 14:11:42 -0700 Subject: [PATCH 01/33] batch matmul initial commit --- .../include/kernels/batch_matmul_kernels.h | 23 +- lib/kernels/src/hip/batch_matmul_kernels.cpp | 7 +- .../include/op-attrs/ops/batch_matmul.h | 4 +- lib/op-attrs/src/batch_matmul.cc | 8 + lib/runtime/src/ops/batch_matmul.cc | 855 ++++-------------- lib/runtime/src/ops/batch_matmul.h | 441 ++++++++- 6 files changed, 655 insertions(+), 683 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 0e4437bdb8..cdffdf1907 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -6,17 +6,26 @@ namespace FlexFlow { -class BatchMatmulPerDeviceState : public PerDeviceOpState { -public: - BatchMatmulPerDeviceState(FFHandler handler); - int a_seq_length_dim, b_seq_length_dim; +struct BMMPerDeviceState { + PerDeviceFFHandle handle; + Allocator allocator; + int a_seq_length_dim; + int b_seq_length_dim; }; +FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, + handle,); + namespace Kernels { namespace BatchMatmul { +BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, + Allocator allocator, + int a_seq_length_dim, + int b_seq_length_dim); + void forward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, + BMMPerDeviceState const *meta, float *o_ptr, float const *a_ptr, float const *b_ptr, @@ -25,12 +34,10 @@ void forward_kernel(ffStream_t stream, int n, int k, int batch, - int a_seq_length_dim = -1, - int b_seq_length_dim = -1, int seq_length = -1); void backward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, + BMMPerDeviceState const *meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp b/lib/kernels/src/hip/batch_matmul_kernels.cpp index 
d8b6500326..e5334b1841 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -19,9 +19,6 @@ namespace FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -42,9 +39,9 @@ void forward_kernel(hipStream_t stream, int k, int batch, hipStream_t stream, - int a_seq_length_dim, - int b_seq_length_dim, int seq_length) { + int a_seq_length_dim = meta->a_seq_length_dim; + int b_seq_length_dim = meta->b_seq_length_dim; checkCUDA(hipblasSetStream(meta->handle.blas, stream)); checkCUDNN(miopenSetStream(meta->handle.dnn, stream)); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index c74824570c..00c700ba20 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -12,8 +12,10 @@ struct BatchMatmulAttrs { }; FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); -CHECK_VALID_OP_ATTR(BatchMatmulAttrs); +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); +CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/batch_matmul.cc b/lib/op-attrs/src/batch_matmul.cc index 1cc8c5cfda..bd61c24737 100644 --- a/lib/op-attrs/src/batch_matmul.cc +++ b/lib/op-attrs/src/batch_matmul.cc @@ -2,6 +2,14 @@ namespace FlexFlow { +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs) { + return attrs.a_seq_length_dim; +} + +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs) { + return attrs.b_seq_length_dim; +} + /* bool BatchMatmulAttrs::is_valid( */ /* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { */ diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 3e860bd413..9bbc050f8b 100644 --- a/lib/runtime/src/ops/batch_matmul.cc 
+++ b/lib/runtime/src/ops/batch_matmul.cc @@ -13,754 +13,287 @@ * limitations under the License. */ +// #include "batch_matmul.h" +// #include "kernels/batch_matmul_kernels.h" +// #include "kernels/profiling.h" +// #include "legion/legion_utilities.h" +// #include "tasks.h" + #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" -#include "kernels/profiling.h" -#include "legion/legion_utilities.h" -#include "tasks.h" +#include "legion.h" +#include "op-attrs/ops/batch_matmul.h" namespace FlexFlow { using namespace FlexFlow::Kernels::BatchMatmul; +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; + enum Slots { - A_INPUT, - B_INPUT, - OUTPUT, - A_INPUT_GRAD, - B_INPUT_GRAD, - OUTPUT_GRAD, - ATTRS, - PROFILING -}; + A_INPUT, //tensor + B_INPUT, //tensor + OUTPUT, //tensor + PROFILING, + HANDLE, + A_SEQ_LENGTH_DIM, + B_SEQ_LENGTH_DIM, + PER_DEVICE_STATE + }; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; + OpTaskBinding init; - b.bind_arg(ATTRS, attrs); - b.bind_arg(PROFILING, enable_profiling()); + init.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init.bind_arg(HANDLE, ff_handle()); - return {BATCHMATMUL_INIT_TASK_ID, b}; + return {BATCHMATMUL_INIT_TASK_ID, init}; } OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; - - b.bind(A_INPUT, input_tensor(0)); - b.bind(B_INPUT, input_tensor(1)); - b.bind(OUTPUT, output_tensor(0)); - - return {BATCHMATMUL_FWD_TASK_ID, b}; -} - -OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return {BATCHMATMUL_BWD_TASK_ID, b}; -} - -BatchMatmulParams BatchMatmul::get_params() const { - BatchMatmulParams params; - params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; - params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; - return params; 
-} - -Tensor FFModel::batch_matmul(const Tensor A, - const Tensor B, - int a_seq_length_dim, - int b_seq_length_dim, - char const *name) { - Layer *bmm = new Layer(this, - OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B); - assert((a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - int dims[MAX_TENSOR_DIM]; - int numdim = A->num_dims; - for (int i = 0; i < numdim; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - bmm->outputs[0] = create_tensor_legion_ordering( - numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); - bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); - bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); - layers.push_back(bmm); - return bmm->outputs[0]; -} - -Op *BatchMatmul::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("a_seq_length_dim", value); - int a_seq_length_dim = value; - layer->get_int_property("b_seq_length_dim", value); - int b_seq_length_dim = value; - return new BatchMatmul(model, - inputs[0], - inputs[1], - a_seq_length_dim, - b_seq_length_dim, - layer->name); -} - -BatchMatmul::BatchMatmul( - FFModel &model, - BatchMatmulParams const ¶ms, - std::pair const &inputs, - char const *name) - : BatchMatmul(model, - inputs.first, - inputs.second, - params.a_seq_length_dim, - params.b_seq_length_dim, - name) {} - -// return A*B -BatchMatmul::BatchMatmul(FFModel &model, - const ParallelTensor A, - const ParallelTensor B, - int _a_seq_length_dim, - int _b_seq_length_dim, - char const *name) - : Op(model, - 
OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B), - a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), - b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { - assert((_a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((_b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < A->num_dims; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - numOutputs = 1; - outputs[0] = model.create_parallel_tensor_legion_ordering( - A->num_dims, dims, DT_FLOAT, this); - // C is not none - // if (C != Tensor::NO_TENSOR) { - // numInputs = 3; - // assert(C.num_dims == outputs[0].num_dims); - // for (int i = 0; i < C.num_dims; i++) - // assert(C.adim[i] == outputs[0].adim[i]); - //} -} - -void BatchMatmul::serialize(Legion::Serializer &sez) const { - BatchMatmulParams params = get_params(); - sez.serialize(params.a_seq_length_dim); - sez.serialize(params.b_seq_length_dim); -} - -using PCG::Node; -/*static*/ -Node BatchMatmul::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 2); - int a_seq_length_dim, b_seq_length_dim; - dez.deserialize(a_seq_length_dim); - dez.deserialize(b_seq_length_dim); - - BatchMatmulParams params; - params.a_seq_length_dim = a_seq_length_dim; - params.b_seq_length_dim = b_seq_length_dim; - return ff.get_or_create_node({inputs[0], inputs[1]}, params); -} + OpTaskBinding fwd; -Op *BatchMatmul::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - BatchMatmulParams params = get_params(); - return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, 
this->name); -} - -template <> -void register_task() { - OpTaskSignature sig(OpTaskType::INIT); + fwd.bind(A_INPUT, input_tensor(0)); + fwd.bind(B_INPUT, input_tensor(1)); + fwd.bind(OUTPUT, output_tensor(0)); - sig.add_arg_slot(ATTRS); - sig.add_arg_slot(PROFILING); + fwd.bind_arg(PROFILING, profiling_settings()); + fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", sig, init_task); + return {BATCHMATMUL_FWD_TASK_ID, fwd}; } -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(A_INPUT, READ_WRITE); - fwd.add_input_slot(B_INPUT, READ_WRITE); - fwd.add_output_slot(OUTPUT); - - return fwd; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_input_slot(A_INPUT); - bwd.add_input_slot(B_INPUT); - bwd.add_input_grad_slot(A_INPUT_GRAD); - bwd.add_input_grad_slot(B_INPUT_GRAD); - bwd.add_output_slot(OUTPUT); - bwd.add_output_grad_slot(OUTPUT_GRAD); - - return bwd; -} - -OpTaskBinding BatchMatmul::get_init_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); +OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { + OpTaskBinding bwd = infer_bwd_binding(forward(attrs).binding); - return binding; + return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -OpTaskBinding BatchMatmul::get_fwd_task_binding() const { - OpTaskBinding binding; - - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind(OUTPUT, output_tensor(0)); - - binding.bind_arg(ATTRS, this->attrs); - return binding; +static DeviceSpecificArg init_task_impl(TaskArgumentAccessor const &acc) { + auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); + auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + Allocator allocator = 
acc.get_allocator(); + + DeviceSpecificArg per_device_state = + acc.create_device_specific( + init_kernel(handle, + allocator, + a_seq_length_dim, + b_seq_length_dim)); + + // assert(weight.shape.get_volume() * sizeof(float) == + // acc.unwrap(per_device_state)->weightSize); + return per_device_state; } -OpTaskBinding BatchMatmul::get_bwd_task_binding() const { - OpTaskBinding binding; - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); - binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); - - binding.bind(OUTPUT, output_tensor(0)); - binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); - - binding.bind_arg(ATTRS, this->attrs); - return binding; -} - -void BatchMatmul::init(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // init_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} -} // namespace FlexFlow -// / -// template -// void BatchMatmul::init_with_dim(FFModel const &ff) { -// assert(check_output_input_weight_same_parallel_is()); -// parallel_is = outputs[0]->parallel_is; -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_init(ff, argmap); -// IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, -// parallel_is, -// TaskArgument(this, sizeof(BatchMatmul)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// 
READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// FutureMap fm = runtime->execute_index_space(ctx, launcher); -// fm.wait_all_results(); -// set_opmeta_from_futuremap(ff, fm); -// } - -PerDeviceOpState * - BatchMatmul::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static DeviceSpecificArg + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - - FFHandler handle = *((FFHandler const *)task->local_args); - BatchMatmulPerDeviceState *m = new BatchMatmulPerDeviceState(handle); - m->profiling = profiling; - m->a_seq_length_dim = attrs.a_seq_length_dim; - m->b_seq_length_dim = attrs.b_seq_length_dim; - return m; -} - -void BatchMatmul::forward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // forward_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} + return init_task_impl(acc); } -// template -// void BatchMatmul::forward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_forward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_FWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// 
launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](O): output - regions[1](I): A - regions[2](I): B - ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? - output = A * B /////////+ C -*/ -void BatchMatmul::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(regions.size() == 3); assert(task->regions.size() == 3); - TaskArgumentAccessor acc(task, regions, ctx, runtime); + auto a_input = acc.get_tensor(A_INPUT); + auto b_input = acc.get_tensor(B_INPUT); + auto output = acc.get_tensor(OUTPUT); - // const BatchMatmul* bmm = (const BatchMatmul*) task->args; + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - // BatchMatmulMeta const *meta = *((BatchMatmulMeta **)task->local_args); - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - auto a_input = acc.get_tensor(A_INPUT); - auto b_input = acc.get_tensor(B_INPUT); - auto output = acc.get_tensor(OUTPUT); - - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); + int batch = 1; for (int i = 
2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); batch *= dim_size; } - float *out_ptr = output.get_float_ptr(); - c float const *a_ptr = a_input.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float const *c_ptr = NULL; - // if (regions.size() == 4) { - // Domain c_domain = runtime->get_index_space_domain( - // ctx, task->regions[3].region.get_index_space()); - // assert(c_domain == a_domain); - // c_ptr = helperGetTensorPointerRO( - // regions[3], task->regions[3], FID_DATA, ctx, runtime); - // } - - profile(forward_kernel, - meta->profiling, + + return profile(forward_kernel, + profiling, "[BatchMatmul] forward_time = %.2lfms\n", - out_ptr, - a_ptr, - b_ptr, - c_ptr, + per_device_state, + output.get_float_ptr(), + a_input.get_float_ptr(), + b_input.get_float_ptr(), + NULL, //c_ptr m, n, k, batch, - meta->a_seq_length_dim, - meta->b_seq_length_dim, iter_config->seq_length); } -void BatchMatmul::backward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - backward_with_dim(ff); \ - break; \ - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); - } +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -// template -// void BatchMatmul::backward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// 
set_argumentmap_for_backward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_BWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// // regions[0](I): output -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// // regions[1](I): output_grad -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region_grad)); -// launcher.add_field(1, FID_DATA); -// // regions[2](I): A -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[0]->region)); -// launcher.add_field(2, FID_DATA); -// // regions[3](I/O): A_grad -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[0]->region_grad)); -// launcher.add_field(3, FID_DATA); -// // regions[4](I): B -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[1]->region)); -// launcher.add_field(4, FID_DATA); -// // regions[5](I/O): B_grad -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[1]->region_grad)); -// launcher.add_field(5, FID_DATA); -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -__host__ void - BatchMatmul::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static 
optional backward_task_impl(TaskArgumentAccessor const &acc) { // Currently assume C is NULL assert(regions.size() == 6); assert(task->regions.size() == 6); + // BatchMatmul* bmm = (BatchMatmul*) task->args; - TaskArgumentAccessor acc(task, regions, ctx, runtime); FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - // output domains - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor(OUTPUT_GRAD); - assert(output == - output_grad); // is this equivalent to checking `Domain` equality? - // A domains - auto a_input = acc.get_tensor(A_INPUT); - auto a_input_grad = acc.get_tensor(A_INPUT_GRAD); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + // is this equivalent to checking `Domain` equality? 
+ assert(output == output_grad); + + auto a_input = acc.get_tensor(A_INPUT); + auto a_input_grad = acc.get_tensor_grad(A_INPUT); assert(a_input == a_input_grad); - // B domains - auto b_input = acc.get_tensor(B_INPUT); - auto b_input_grad = acc.get_tensor(B_INPUT_GRAD); + + auto b_input = acc.get_tensor(B_INPUT); + auto b_input_grad = acc.get_tensor_grad(B_INPUT); assert(b_input == b_input_grad); // check dins - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); batch *= dim_size; } - // get pointers - float const *out_ptr = output.get_float_ptr(); - float const *out_grad_ptr = output_grad.get_float_ptr(); - float const *a_ptr = a_input.get_float_ptr(); - float *a_grad_ptr = a_input_grad.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float *b_grad_ptr = b_input_grad.get_float_ptr(); - - float *c_grad_ptr = NULL; // TODO: add support for meta->a_seq_length_dim >= 0 // or meta->b_seq_length_dim >= 0 assert((meta->a_seq_length_dim >= a_len) || (iter_config->seq_length == 0)); assert((meta->b_seq_length_dim >= b_len) || (iter_config->seq_length == 0)); - profile(backward_kernel, - meta->profiling, + return profile(backward_kernel, 
+ profiling, "[BatchMatmul] backward_time = %.2lfms\n", - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + a_input.get_float_ptr(), + a_input_grad.get_float_ptr(), + b_input.get_float_ptr(), + b_input_grad.get_float_ptr(), + NULL, //c_grad_ptr m, n, k, batch); } -void BatchMatmul::print_layer(FFModel const &ff) { - return; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -bool BatchMatmul::measure_operator_cost(Simulator *sim, - MachineView const &pc, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input0, sub_input1; - if (!outputs[0]->get_sub_tensor(pc, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(pc, sub_input0)) { - return false; - } - if (!inputs[1]->get_sub_tensor(pc, sub_input1)) { - return false; - } - - int input0_c = sub_input0.dims[0].size; - int input0_r = sub_input0.dims[1].size; - int input1_c = sub_input1.dims[0].size; - int input1_r = sub_input1.dims[1].size; - int output_c = sub_output.dims[0].size; - int output_r = sub_output.dims[1].size; - - assert(input0_c == input1_r); - assert(input0_r == output_r); - assert(input1_c == output_c); - - assert(sub_input0.dims[2] == sub_input1.dims[2]); - assert(sub_input1.dims[2] == sub_output.dims[2]); - int batch = 1; - assert(sub_input0.num_dims == sub_input1.num_dims); - for (int i = 2; i < sub_input0.num_dims; i++) { - assert(sub_input0.dims[i] == sub_input1.dims[i]); - assert(sub_input0.dims[i] == sub_output.dims[i]); - batch *= sub_input0.dims[i].size; - } - - BatchMatmulPerDeviceState *meta = sim->batch_matmul_meta; +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, 
+ ProfilingSettings const &settings, + MachineView const &pc) const { + auto env = sim.new_environment(); + + //todo add get_output_shape and get_weights_shape to batch_matmul op-attrs + // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); + // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); + + SimTaskBinding init_binding; + init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init_binding.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init_binding.bind_arg(HANDLE, ff_handle()); + + auto init_accessor = + env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, init_binding); + DeviceSpecificArg per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(A_INPUT, a_input); + fwd_binding.bind(B_INPUT, b_input); + // fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); +} - // allocate tensors in simulator - sim->free_all(); - float *a_ptr = (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - assert(a_ptr != NULL); - float *b_ptr = (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - assert(b_ptr != NULL); - float *c_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); - float *out_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_ptr != NULL); 
- cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + init.add_arg_slot(A_SEQ_LENGTH_DIM); + init.add_arg_slot(B_SEQ_LENGTH_DIM); + init.add_unchecked_arg_slot(HANDLE); - int m = input1_c; - int n = input0_r; - int k = input0_c; + register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); +} - assert(meta->profiling == false); +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, meta, out_ptr, a_ptr, b_ptr, c_ptr, m, n, k, batch); - }; + fwd.add_input_slot(A_INPUT); + fwd.add_input_slot(B_INPUT); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - if (sim->computationMode == COMP_MODE_TRAINING) { - float *a_grad_ptr = - (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - float *b_grad_ptr = - (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - float *c_grad_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *out_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, - m, - n, - k, - batch); - }; - } + register_task(BATCHMATMUL_FWD_TASK_ID, "BatchMatmul Fwd", fwd, forward_task); +} - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf) backward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - 
output_r, - output_c, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time); - } +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(BATCHMATMUL_FWD_TASK_ID)); - return true; + register_task(BATCHMATMUL_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); } -} -; // namespace FlexFlow + +}; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index c133c2a875..ab9bb45e8d 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -1,10 +1,14 @@ #ifndef _FLEXFLOW_BATCH_MATMUL_H #define _FLEXFLOW_BATCH_MATMUL_H +// #include "op-attrs/ops/batch_matmul.h" +// #include "task_spec/op_task_invocation.h" +// #include "task_spec/op_task_signature.h" +// #include "sim_environment.h" + #include "op-attrs/ops/batch_matmul.h" -#include "op_task_invocation.h" -#include "op_task_signature.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -19,12 +23,12 @@ OpTaskInvocation init(BatchMatmulAttrs const &); OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - BatchMatmulAttrs const &attrs, - ParallelTensorShape const &lhs_input_shape, - ParallelTensorShape const &rhs_input_shape, - ProfilingSettings const &settings, - MachineView const &); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc); /* class BatchMatmul : public Op { */ /* public:
*/ @@ -84,3 +88,424 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + + + +// BatchMatmulParams BatchMatmul::get_params() const { +// BatchMatmulParams params; +// params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; +// params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; +// return params; +// } + +// Tensor FFModel::batch_matmul(const Tensor A, +// const Tensor B, +// int a_seq_length_dim, +// int b_seq_length_dim, +// char const *name) { +// Layer *bmm = new Layer(this, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B); +// assert((a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// int dims[MAX_TENSOR_DIM]; +// int numdim = A->num_dims; +// for (int i = 0; i < numdim; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// bmm->outputs[0] = create_tensor_legion_ordering( +// numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); +// bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); +// bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); +// layers.push_back(bmm); +// return bmm->outputs[0]; +// } + +// Op *BatchMatmul::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("a_seq_length_dim", value); +// int a_seq_length_dim = value; +// layer->get_int_property("b_seq_length_dim", value); +// int b_seq_length_dim = value; +// return new BatchMatmul(model, +// inputs[0], +// inputs[1], +// 
a_seq_length_dim, +// b_seq_length_dim, +// layer->name); +// } + +// BatchMatmul::BatchMatmul( +// FFModel &model, +// BatchMatmulParams const ¶ms, +// std::pair const &inputs, +// char const *name) +// : BatchMatmul(model, +// inputs.first, +// inputs.second, +// params.a_seq_length_dim, +// params.b_seq_length_dim, +// name) {} + +// // return A*B +// BatchMatmul::BatchMatmul(FFModel &model, +// const ParallelTensor A, +// const ParallelTensor B, +// int _a_seq_length_dim, +// int _b_seq_length_dim, +// char const *name) +// : Op(model, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B), +// a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), +// b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { +// assert((_a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((_b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < A->num_dims; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// numOutputs = 1; +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// A->num_dims, dims, DT_FLOAT, this); +// // C is not none +// // if (C != Tensor::NO_TENSOR) { +// // numInputs = 3; +// // assert(C.num_dims == outputs[0].num_dims); +// // for (int i = 0; i < C.num_dims; i++) +// // assert(C.adim[i] == outputs[0].adim[i]); +// //} +// } + +// void BatchMatmul::serialize(Legion::Serializer &sez) const { +// BatchMatmulParams params = get_params(); +// sez.serialize(params.a_seq_length_dim); +// sez.serialize(params.b_seq_length_dim); +// } + +// using PCG::Node; +// /*static*/ +// Node 
BatchMatmul::deserialize(FFModel &ff, +// Legion::Deserializer &dez, +// ParallelTensor inputs[], +// int num_inputs) { +// assert(num_inputs == 2); +// int a_seq_length_dim, b_seq_length_dim; +// dez.deserialize(a_seq_length_dim); +// dez.deserialize(b_seq_length_dim); + +// BatchMatmulParams params; +// params.a_seq_length_dim = a_seq_length_dim; +// params.b_seq_length_dim = b_seq_length_dim; +// return ff.get_or_create_node({inputs[0], inputs[1]}, params); +// } + +// Op *BatchMatmul::materialize(FFModel &ff, +// ParallelTensor inputs[], +// int num_inputs) const { +// BatchMatmulParams params = get_params(); +// return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); +// } + + +// void BatchMatmul::forward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // forward_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); +// break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + +// template +// void BatchMatmul::forward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_FWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// for (int i = 0; i < numInputs; i++) { +// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[i]->region)); +// launcher.add_field(i + 1, FID_DATA); +// } +// 
runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](O): output + regions[1](I): A + regions[2](I): B + ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? + output = A * B /////////+ C +*/ + + +// void BatchMatmul::init(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // init_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); +// break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } // namespace FlexFlow +// // / +// // template +// // void BatchMatmul::init_with_dim(FFModel const &ff) { +// // assert(check_output_input_weight_same_parallel_is()); +// // parallel_is = outputs[0]->parallel_is; +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_init(ff, argmap); +// // IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(BatchMatmul)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // FutureMap fm = runtime->execute_index_space(ctx, launcher); +// // fm.wait_all_results(); +// // set_opmeta_from_futuremap(ff, fm); +// // } + +// OpTaskBinding BatchMatmul::get_bwd_task_binding() const { +// OpTaskBinding binding; +// binding.bind(A_INPUT, input_tensor(0)); +// 
binding.bind(B_INPUT, input_tensor(1)); +// binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); +// binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); + +// binding.bind(OUTPUT, output_tensor(0)); +// binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + + +// static OpTaskSignature get_fwd_task_signature() { +// OpTaskSignature fwd(OpTaskType::FWD); + +// fwd.add_input_slot(A_INPUT, READ_WRITE); +// fwd.add_input_slot(B_INPUT, READ_WRITE); +// fwd.add_output_slot(OUTPUT); + +// return fwd; +// } + +// static OpTaskSignature get_bwd_task_signature() { +// OpTaskSignature bwd(OpTaskType::BWD); + +// bwd.add_input_slot(A_INPUT); +// bwd.add_input_slot(B_INPUT); +// bwd.add_input_grad_slot(A_INPUT_GRAD); +// bwd.add_input_grad_slot(B_INPUT_GRAD); +// bwd.add_output_slot(OUTPUT); +// bwd.add_output_grad_slot(OUTPUT_GRAD); + +// return bwd; +// } + +// OpTaskBinding BatchMatmul::get_init_task_binding() const { +// OpTaskBinding binding; + +// binding.bind_arg(ATTRS, this->attrs); +// binding.bind_arg(PROFILING, this->profiling); + +// return binding; +// } + +// OpTaskBinding BatchMatmul::get_fwd_task_binding() const { +// OpTaskBinding binding; + +// binding.bind(A_INPUT, input_tensor(0)); +// binding.bind(B_INPUT, input_tensor(1)); +// binding.bind(OUTPUT, output_tensor(0)); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + +//void BatchMatmul::backward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// backward_with_dim(ff); \ +// break; \ +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + + +// void BatchMatmul::print_layer(FFModel const &ff) { +// return; +// } + + + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + regions[5](I/O): B_grad + regions[6](I/O): C_grad 
+*/ +// template +// void BatchMatmul::backward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_BWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// // regions[0](I): output +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// // regions[1](I): output_grad +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// // regions[2](I): A +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(2, FID_DATA); +// // regions[3](I/O): A_grad +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // regions[4](I): B +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[1]->region)); +// launcher.add_field(4, FID_DATA); +// // regions[5](I/O): B_grad +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[1]->region_grad)); +// launcher.add_field(5, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + 
regions[5](I/O): B_grad + regions[6](I/O): C_grad +*/ \ No newline at end of file From d9aae744914fb6d710a5f174d4e88cc0cd7a95d5 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 22 Aug 2023 14:29:35 -0700 Subject: [PATCH 02/33] fix FF_VISITABLE_STRUCT_NO_EQ --- lib/kernels/include/kernels/batch_matmul_kernels.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index cdffdf1907..b8e6b9cd4b 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -14,7 +14,10 @@ struct BMMPerDeviceState { }; FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, - handle,); + handle, + allocator, + a_seq_length_dim, + b_seq_length_dim); namespace Kernels { namespace BatchMatmul { From 2073822dbc697615407763f10a8383e864c6929e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 12:03:39 -0700 Subject: [PATCH 03/33] finish draft 1 batch_matmul --- .../include/kernels/batch_matmul_kernels.h | 7 ++-- lib/runtime/include/runtime/config.h | 8 +++-- lib/runtime/src/ops/batch_matmul.cc | 32 ++++++++++--------- lib/runtime/src/task_spec/op_arg_ref.h | 8 ++++- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index b8e6b9cd4b..1506305e9e 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -10,7 +10,7 @@ struct BMMPerDeviceState { PerDeviceFFHandle handle; Allocator allocator; int a_seq_length_dim; - int b_seq_length_dim; + req b_seq_length_dim; }; FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, @@ -28,7 +28,7 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, int b_seq_length_dim); void forward_kernel(ffStream_t stream, - BMMPerDeviceState const *meta, + BMMPerDeviceState const &meta, float *o_ptr, float const 
*a_ptr, float const *b_ptr, @@ -40,14 +40,13 @@ void forward_kernel(ffStream_t stream, int seq_length = -1); void backward_kernel(ffStream_t stream, - BMMPerDeviceState const *meta, + BMMPerDeviceState const &meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, float *a_grad_ptr, float const *b_ptr, float *b_grad_ptr, - float *c_grad_ptr, int m, int n, int k, diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index a7b8d86171..fc45c2c76c 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -104,13 +104,15 @@ struct FFConfig : public use_visitable_cmp { int python_data_loader_type = 2; }; -class FFIterationConfig { -public: +struct FFIterationConfig { FFIterationConfig(); void reset(); - int seq_length; + req seq_length; }; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, + seq_length); + enum FieldIDs { FID_DATA, }; diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 9bbc050f8b..761c565657 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -41,7 +41,8 @@ enum Slots { HANDLE, A_SEQ_LENGTH_DIM, B_SEQ_LENGTH_DIM, - PER_DEVICE_STATE + PER_DEVICE_STATE, + ITERATION_CONFIG }; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { @@ -63,6 +64,7 @@ OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { fwd.bind_arg(PROFILING, profiling_settings()); fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + fwd.bind_arg(ITERATION_CONFIG, iteration_config()); return {BATCHMATMUL_FWD_TASK_ID, fwd}; } @@ -110,7 +112,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; + FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); int m = 
b_input.shape[legion_dim_t(0)]; assert(m == output.shape[legion_dim_t(0)]); @@ -123,7 +125,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { + for (int i = 2; i < a_input.shape.get_dim(); i++) { //get_dim() or get_volume()? int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -137,12 +139,12 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { output.get_float_ptr(), a_input.get_float_ptr(), b_input.get_float_ptr(), - NULL, //c_ptr + nullptr, //c_ptr m, n, k, batch, - iter_config->seq_length); + iter_config.seq_length); } static void forward_task(Task const *task, @@ -159,7 +161,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(task->regions.size() == 6); // BatchMatmul* bmm = (BatchMatmul*) task->args; - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; + FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); @@ -186,7 +188,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { + for (int i = 2; i < a_input.shape.get_dim(); i++) { //@colin get_dim() or get_volume()? 
int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -195,8 +197,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { // TODO: add support for meta->a_seq_length_dim >= 0 // or meta->b_seq_length_dim >= 0 - assert((meta->a_seq_length_dim >= a_len) || (iter_config->seq_length == 0)); - assert((meta->b_seq_length_dim >= b_len) || (iter_config->seq_length == 0)); + assert((meta->a_seq_length_dim >= a_len) || (iter_config.seq_length == 0)); + assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); return profile(backward_kernel, profiling, @@ -208,7 +210,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { a_input_grad.get_float_ptr(), b_input.get_float_ptr(), b_input_grad.get_float_ptr(), - NULL, //c_grad_ptr m, n, k, @@ -228,10 +229,10 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, InputParallelTensorDesc const &a_input, InputParallelTensorDesc const &b_input, ProfilingSettings const &settings, - MachineView const &pc) const { + MachineView const &pc) { auto env = sim.new_environment(); - //todo add get_output_shape and get_weights_shape to batch_matmul op-attrs + // @colin todo add get_output_shape and get_weights_shape to batch_matmul op-attrs // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); @@ -248,6 +249,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding fwd_binding; fwd_binding.bind(A_INPUT, a_input); fwd_binding.bind(B_INPUT, b_input); + //@colin will uncomment after get_output_shape is done // fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); @@ -268,9 +270,9 @@ template <> void register_task() { OpTaskSignature init(OpTaskType::INIT); - init.add_arg_slot(A_SEQ_LENGTH_DIM, 
get_aSeqLengthDim(attrs)); - init.add_arg_slot(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); - init.add_unchecked_arg_slot(HANDLE, ff_handle()); + init.add_arg_slot(A_SEQ_LENGTH_DIM); + init.add_arg_slot(B_SEQ_LENGTH_DIM); + init.add_unchecked_arg_slot(HANDLE); register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); } diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index 6e921c05e8..bc4f38e5fb 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -4,10 +4,12 @@ #include "arg_ref.h" #include "device_specific_arg.h" #include "op-attrs/parallel_tensor_shape.h" +#include "runtime/config.h" + namespace FlexFlow { -enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE }; +enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE, ITERATION_CONFIG }; template using OpArgRef = ArgRef; @@ -23,6 +25,10 @@ OpArgRef input_parallel_tensor_shape(int idx) { return {OpArgRefType::PARALLEL_TENSOR_SHAPE}; } +OpArgRef iteration_config() { + return {OpArgRefType::ITERATION_CONFIG}; +} + } // namespace FlexFlow #endif From d5806a01163ab9fb13f0457c7cc4ba4e330b41d0 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 13:35:55 -0700 Subject: [PATCH 04/33] add output and weights --- lib/op-attrs/include/op-attrs/ops/batch_matmul.h | 3 +++ lib/runtime/src/ops/batch_matmul.cc | 15 ++++----------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 00c700ba20..fa80ed3da9 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -14,6 +14,9 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); +ParallelTensorShape 
get_weights_shape(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 761c565657..3fabc86d1a 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -13,16 +13,11 @@ * limitations under the License. */ -// #include "batch_matmul.h" -// #include "kernels/batch_matmul_kernels.h" -// #include "kernels/profiling.h" -// #include "legion/legion_utilities.h" -// #include "tasks.h" - #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" #include "legion.h" #include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/get_output_shapes.h" namespace FlexFlow { @@ -232,9 +227,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, MachineView const &pc) { auto env = sim.new_environment(); - // @colin todo add get_output_shape and get_weights_shape to batch_matmul op-attrs - // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); - // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); + ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape weight_shape = get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); @@ -249,8 +243,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding fwd_binding; fwd_binding.bind(A_INPUT, a_input); fwd_binding.bind(B_INPUT, b_input); - //@colin will uncomment after get_output_shape is done - // fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); From f620070ea8bc8e7373915a5063e1a5977ed7f433 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 13:45:00 
-0700 Subject: [PATCH 05/33] format --- .../include/kernels/batch_matmul_kernels.h | 15 +-- .../include/op-attrs/ops/batch_matmul.h | 4 +- lib/runtime/include/runtime/config.h | 3 +- lib/runtime/src/ops/batch_matmul.cc | 106 +++++++++--------- lib/runtime/src/ops/batch_matmul.h | 62 +++++----- lib/runtime/src/task_spec/op_arg_ref.h | 7 +- 6 files changed, 97 insertions(+), 100 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 1506305e9e..23522a89f7 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -13,19 +13,16 @@ struct BMMPerDeviceState { req b_seq_length_dim; }; -FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, - handle, - allocator, - a_seq_length_dim, - b_seq_length_dim); +FF_VISITABLE_STRUCT_NO_EQ( + BMMPerDeviceState, handle, allocator, a_seq_length_dim, b_seq_length_dim); namespace Kernels { namespace BatchMatmul { BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, - Allocator allocator, - int a_seq_length_dim, - int b_seq_length_dim); + Allocator allocator, + int a_seq_length_dim, + int b_seq_length_dim); void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, @@ -40,7 +37,7 @@ void forward_kernel(ffStream_t stream, int seq_length = -1); void backward_kernel(ffStream_t stream, - BMMPerDeviceState const &meta, + BMMPerDeviceState const &meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index fa80ed3da9..6ac2a2cf69 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -15,8 +15,8 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); ParallelTensorShape 
get_weights_shape(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); + ParallelTensorShape const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index fc45c2c76c..21cfca20cf 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -110,8 +110,7 @@ struct FFIterationConfig { req seq_length; }; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, - seq_length); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); enum FieldIDs { FID_DATA, diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 3fabc86d1a..0f4ad9e865 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -16,8 +16,8 @@ #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" #include "legion.h" -#include "op-attrs/ops/batch_matmul.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/batch_matmul.h" namespace FlexFlow { @@ -29,16 +29,16 @@ using Legion::Runtime; using Legion::Task; enum Slots { - A_INPUT, //tensor - B_INPUT, //tensor - OUTPUT, //tensor + A_INPUT, // tensor + B_INPUT, // tensor + OUTPUT, // tensor PROFILING, HANDLE, A_SEQ_LENGTH_DIM, B_SEQ_LENGTH_DIM, PER_DEVICE_STATE, ITERATION_CONFIG - }; +}; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { OpTaskBinding init; @@ -70,18 +70,16 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static DeviceSpecificArg init_task_impl(TaskArgumentAccessor const &acc) { +static DeviceSpecificArg + init_task_impl(TaskArgumentAccessor const &acc) { auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = 
acc.get_allocator(); - DeviceSpecificArg per_device_state = + DeviceSpecificArg per_device_state = acc.create_device_specific( - init_kernel(handle, - allocator, - a_seq_length_dim, - b_seq_length_dim)); + init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); // assert(weight.shape.get_volume() * sizeof(float) == // acc.unwrap(per_device_state)->weightSize); @@ -107,7 +105,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); - FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); int m = b_input.shape[legion_dim_t(0)]; assert(m == output.shape[legion_dim_t(0)]); @@ -120,7 +119,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); i++) { //get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.get_dim(); + i++) { // get_dim() or get_volume()? 
int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -128,18 +128,18 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { } return profile(forward_kernel, - profiling, - "[BatchMatmul] forward_time = %.2lfms\n", - per_device_state, - output.get_float_ptr(), - a_input.get_float_ptr(), - b_input.get_float_ptr(), - nullptr, //c_ptr - m, - n, - k, - batch, - iter_config.seq_length); + profiling, + "[BatchMatmul] forward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + a_input.get_float_ptr(), + b_input.get_float_ptr(), + nullptr, // c_ptr + m, + n, + k, + batch, + iter_config.seq_length); } static void forward_task(Task const *task, @@ -156,14 +156,15 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(task->regions.size() == 6); // BatchMatmul* bmm = (BatchMatmul*) task->args; - FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); // is this equivalent to checking `Domain` equality? - assert(output == output_grad); + assert(output == output_grad); auto a_input = acc.get_tensor(A_INPUT); auto a_input_grad = acc.get_tensor_grad(A_INPUT); @@ -183,7 +184,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); i++) { //@colin get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.get_dim(); + i++) { //@colin get_dim() or get_volume()? 
int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -196,19 +198,19 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); return profile(backward_kernel, - profiling, - "[BatchMatmul] backward_time = %.2lfms\n", - per_device_state, - output.get_float_ptr(), - output_grad.get_float_ptr(), - a_input.get_float_ptr(), - a_input_grad.get_float_ptr(), - b_input.get_float_ptr(), - b_input_grad.get_float_ptr(), - m, - n, - k, - batch); + profiling, + "[BatchMatmul] backward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + a_input.get_float_ptr(), + a_input_grad.get_float_ptr(), + b_input.get_float_ptr(), + b_input_grad.get_float_ptr(), + m, + n, + k, + batch); } static void backward_task(Task const *task, @@ -220,15 +222,17 @@ static void backward_task(Task const *task, } CostMetrics measure_operator_cost(SimEnvFactory const &sim, - BatchMatmulAttrs const &attrs, - InputParallelTensorDesc const &a_input, - InputParallelTensorDesc const &b_input, - ProfilingSettings const &settings, - MachineView const &pc) { + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc) { auto env = sim.new_environment(); - ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); - ParallelTensorShape weight_shape = get_weights_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape output_shape = + get_output_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape weight_shape = + get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); @@ -249,8 +253,10 @@ CostMetrics 
measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - auto fwd_accessor = env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); - auto bwd_accessor = env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); + auto fwd_accessor = + env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); float forward_time = forward_task_impl(fwd_accessor).value(); float backward_time = backward_task_impl(bwd_accessor).value(); diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index ab9bb45e8d..42c9bc23de 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -24,11 +24,11 @@ OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim, - BatchMatmulAttrs const &attrs, - InputParallelTensorDesc const &a_input, - InputParallelTensorDesc const &b_input, - ProfilingSettings const &settings, - MachineView const &pc); + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc); /* class BatchMatmul : public Op { */ /* public: */ @@ -89,8 +89,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, #endif - - // BatchMatmulParams BatchMatmul::get_params() const { // BatchMatmulParams params; // params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; @@ -242,15 +240,14 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); // } - // void BatchMatmul::forward(FFModel const &ff) { // int dim = outputs[0]->num_dims; // switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ +// #define DIMFUNC(DIM) \ +// case DIM: { \ // // 
forward_with_dim(ff); -// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); -// break; +// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, +// get_fwd_task_signature()); break; // } // LEGION_FOREACH_N(DIMFUNC) // #undef DIMFUNC @@ -299,15 +296,14 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, output = A * B /////////+ C */ - // void BatchMatmul::init(FFModel const &ff) { // int dim = outputs[0]->num_dims; // switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ +// #define DIMFUNC(DIM) \ +// case DIM: { \ // // init_with_dim(ff); -// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); -// break; +// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, +// get_init_task_signature()); break; // } // LEGION_FOREACH_N(DIMFUNC) // #undef DIMFUNC @@ -365,7 +361,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return binding; // } - // static OpTaskSignature get_fwd_task_signature() { // OpTaskSignature fwd(OpTaskType::FWD); @@ -409,28 +404,25 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return binding; // } -//void BatchMatmul::backward(FFModel const &ff) { -// int dim = outputs[0]->num_dims; -// switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ -// backward_with_dim(ff); \ -// break; \ +// void BatchMatmul::backward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #d ef ine DIMFUNC(DIM) \ +// case DIM: { \ +// backward_with_dim(ff); \ +// break; \ // } -// LEGION_FOREACH_N(DIMFUNC) -// #undef DIMFUNC -// default: -// assert(false); -// } -// } - +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } // void BatchMatmul::print_layer(FFModel const &ff) { // return; // } - - /* regions[0](I): output regions[1](I): output_grad @@ -508,4 +500,4 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, regions[4](I): B regions[5](I/O): B_grad regions[6](I/O): C_grad -*/ \ No newline at 
end of file +*/ diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index bc4f38e5fb..07316c877b 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -6,10 +6,13 @@ #include "op-attrs/parallel_tensor_shape.h" #include "runtime/config.h" - namespace FlexFlow { -enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE, ITERATION_CONFIG }; +enum class OpArgRefType { + PER_DEVICE_OP_STATE, + PARALLEL_TENSOR_SHAPE, + ITERATION_CONFIG +}; template using OpArgRef = ArgRef; From 5458923bc344d2cefbd6bc18b7c3719e73291ba8 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 15:23:36 -0700 Subject: [PATCH 06/33] fix DeviceSpecific --- lib/runtime/src/ops/batch_matmul.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 0f4ad9e865..59570f4417 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -70,14 +70,14 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static DeviceSpecificArg +static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = acc.get_allocator(); - DeviceSpecificArg per_device_state = + DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); @@ -86,7 +86,7 @@ static DeviceSpecificArg return per_device_state; } -static DeviceSpecificArg +static DeviceSpecific init_task(Task const *task, std::vector const ®ions, Context ctx, @@ -241,7 +241,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, auto init_accessor = env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, 
init_binding); - DeviceSpecificArg per_device_state = + DeviceSpecific per_device_state = init_task_impl(init_accessor); SimTaskBinding fwd_binding; From f2205f43b18ca97e332c11d56668e1c0516c6826 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:26:59 -0700 Subject: [PATCH 07/33] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 23522a89f7..6850dec178 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -19,8 +19,8 @@ FF_VISITABLE_STRUCT_NO_EQ( namespace Kernels { namespace BatchMatmul { -BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, - Allocator allocator, +BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator const &allocator, int a_seq_length_dim, int b_seq_length_dim); From 2f4662da4fcb64dad9de7dd6294e8e6baa3e532e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:28:45 -0700 Subject: [PATCH 08/33] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 6850dec178..e5062b2c61 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -26,7 +26,7 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, - float *o_ptr, + float *output_ptr, float const *a_ptr, float const *b_ptr, float const *c_ptr, From 642eb90d99bdb83b3a212be50232520b497816ba Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:34:04 -0700 Subject: [PATCH 09/33] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 5 ++--- 
lib/runtime/src/ops/batch_matmul.cc | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index e5062b2c61..ec32648d0f 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -27,9 +27,8 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, float *output_ptr, - float const *a_ptr, - float const *b_ptr, - float const *c_ptr, + float const *lhs_input_ptr, + float const *rhs_input_ptr, int m, int n, int k, diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 59570f4417..09eb1f23c7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -134,7 +134,6 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { output.get_float_ptr(), a_input.get_float_ptr(), b_input.get_float_ptr(), - nullptr, // c_ptr m, n, k, From e79406a842117e7823d8d63ed21ad5613c846078 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:46:00 -0700 Subject: [PATCH 10/33] change --- lib/kernels/src/cuda/batch_matmul_kernels.cu | 5 +---- lib/kernels/src/hip/batch_matmul_kernels.cpp | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index 3593ac4ab2..cde0df93c0 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -18,9 +18,6 @@ namespace FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -124,7 +121,7 @@ O = A * B */ void forward_kernel(cudaStream_t stream, - BatchMatmulPerDeviceState const *meta, + BatchMatmulPerDeviceState const &meta, float *o_ptr, float 
const *a_ptr, float const *b_ptr, diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp b/lib/kernels/src/hip/batch_matmul_kernels.cpp index e5334b1841..a06442d3d6 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -29,7 +29,7 @@ O: (batch, n, m) O = A * B */ void forward_kernel(hipStream_t stream, - BatchMatmulPerDeviceState const *meta, + BatchMatmulPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, From bb3c10f8a9f46c9a09d514a9c70f4e72e37e3e94 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:51:31 -0700 Subject: [PATCH 11/33] change --- lib/op-attrs/include/op-attrs/ops/batch_matmul.h | 3 --- lib/runtime/src/ops/batch_matmul.cc | 6 +----- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 6ac2a2cf69..00c700ba20 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -14,9 +14,6 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); -ParallelTensorShape get_weights_shape(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 09eb1f23c7..669cad215b 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -80,9 +80,7 @@ static DeviceSpecific DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); - - // assert(weight.shape.get_volume() * sizeof(float) == - // acc.unwrap(per_device_state)->weightSize); + return per_device_state; } @@ -230,8 +228,6 @@ 
CostMetrics measure_operator_cost(SimEnvFactory const &sim, ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); - ParallelTensorShape weight_shape = - get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); From 848850900f9d44fe29df90c69a2099cb71be299e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:00:34 -0700 Subject: [PATCH 12/33] change --- lib/runtime/include/runtime/config.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index 21cfca20cf..ef7e779469 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -107,7 +107,7 @@ struct FFConfig : public use_visitable_cmp { struct FFIterationConfig { FFIterationConfig(); void reset(); - req seq_length; + int seq_length; }; FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); From 36eba295424c5c58ea0e3aedea76e3f3adfd3790 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:03:04 -0700 Subject: [PATCH 13/33] change --- lib/runtime/src/ops/batch_matmul.h | 55 ------------------------------ 1 file changed, 55 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index 42c9bc23de..018fe1d582 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -30,61 +30,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, ProfilingSettings const &settings, MachineView const &pc); -/* class BatchMatmul : public Op { */ -/* public: */ -/* BatchMatmul(FFModel &model, */ -/* const ParallelTensor A, */ -/* const ParallelTensor B, */ -/* int a_seq_length_dim, */ -/* int b_seq_length_dim, */ -/* char const *name = nullptr); */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* 
std::vector const &inputs); - */ - -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* /1* static PCG::Node deserialize(FFModel &ff, *1/ */ -/* /1* Legion::Deserializer &d, *1/ */ -/* /1* ParallelTensor inputs[], *1/ */ -/* /1* int num_inputs); *1/ */ -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* OpTaskBinding get_init_task_binding() const override; */ -/* OpTaskBinding get_fwd_task_binding() const override; */ -/* OpTaskBinding get_bwd_task_binding() const override; */ -/* private: */ -/* template */ -/* void init_with_dim(FFModel const &ff); */ -/* template */ -/* void forward_with_dim(FFModel const &ff); */ -/* template */ -/* void backward_with_dim(FFModel const &ff); */ - -/* public: */ -/* int a_seq_length_dim, b_seq_length_dim; */ -/* }; */ - } // namespace FlexFlow #endif From 98773b4237de4a4585711bca14d26b9d3c6e75ff Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:15:53 -0700 Subject: [PATCH 14/33] change --- lib/runtime/src/ops/batch_matmul.cc | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 669cad215b..a8c2ec7bd7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -72,8 +72,8 @@ OpTaskInvocation backward(BatchMatmulAttrs const 
&attrs) { static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { - auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); - auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); + int const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); + int const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = acc.get_allocator(); @@ -94,9 +94,6 @@ static DeviceSpecific } static optional forward_task_impl(TaskArgumentAccessor const &acc) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - auto a_input = acc.get_tensor(A_INPUT); auto b_input = acc.get_tensor(B_INPUT); auto output = acc.get_tensor(OUTPUT); @@ -148,10 +145,6 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - // Currently assume C is NULL - assert(regions.size() == 6); - assert(task->regions.size() == 6); - // BatchMatmul* bmm = (BatchMatmul*) task->args; FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); @@ -161,15 +154,15 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); // is this equivalent to checking `Domain` equality? 
- assert(output == output_grad); + assert(output.shape == output_grad.shape); auto a_input = acc.get_tensor(A_INPUT); auto a_input_grad = acc.get_tensor_grad(A_INPUT); - assert(a_input == a_input_grad); + assert(a_input.shape == a_input_grad.shape); auto b_input = acc.get_tensor(B_INPUT); auto b_input_grad = acc.get_tensor_grad(B_INPUT); - assert(b_input == b_input_grad); + assert(b_input.shape == b_input_grad.shape); // check dins int m = b_input.shape[legion_dim_t(0)]; @@ -181,8 +174,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); - i++) { //@colin get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.dims.num_dims(); + i++) { int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); From ae592613cf3fa6c763ca0f14b30a1899831ee4a9 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:28:18 -0700 Subject: [PATCH 15/33] change --- lib/runtime/src/task_spec/op_arg_ref.h | 10 ++-------- lib/runtime/src/task_spec/runtime_arg_ref.h | 2 ++ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index f2361af05d..20ce6892a2 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -4,15 +4,13 @@ #include "arg_ref.h" #include "device_specific.h" #include "op-attrs/parallel_tensor_shape.h" -#include "runtime/config.h" namespace FlexFlow { enum class OpArgRefType { PER_DEVICE_OP_STATE, - PARALLEL_TENSOR_SHAPE, - ITERATION_CONFIG -}; + PARALLEL_TENSOR_SHAPE + }; template using OpArgRef = ArgRef; @@ -28,10 +26,6 @@ OpArgRef input_parallel_tensor_shape(int idx) { return {OpArgRefType::PARALLEL_TENSOR_SHAPE}; } -OpArgRef iteration_config() { - 
return {OpArgRefType::ITERATION_CONFIG}; -} - } // namespace FlexFlow #endif diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.h b/lib/runtime/src/task_spec/runtime_arg_ref.h index 6b4345091a..033c2bcfbc 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.h +++ b/lib/runtime/src/task_spec/runtime_arg_ref.h @@ -3,6 +3,7 @@ #include "arg_ref.h" #include "device_specific.h" +#include "runtime/config.h" namespace FlexFlow { @@ -15,6 +16,7 @@ using RuntimeArgRefSpec = ArgRefSpec; RuntimeArgRef profiling_settings(); RuntimeArgRef> ff_handle(); +RuntimeArgRef iteration_config(); } // namespace FlexFlow From 8db251252670a8107430318da166e2ba059879ca Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 17:50:28 -0700 Subject: [PATCH 16/33] fix asserts --- lib/runtime/src/ops/attention.cc | 15 --------------- lib/runtime/src/ops/batch_matmul.cc | 13 ++++--------- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/lib/runtime/src/ops/attention.cc b/lib/runtime/src/ops/attention.cc index bca87bdb53..94e2b03731 100644 --- a/lib/runtime/src/ops/attention.cc +++ b/lib/runtime/src/ops/attention.cc @@ -121,18 +121,6 @@ static DeviceSpecific int num_samples = get_piece_shape(query_parallel_tensor_shape)[ff_dim_t(2)]; int num_heads = get_piece_shape(weight_parallel_tensor_shape)[ff_dim_t(1)]; - assert(qoSeqLength == query.shape[legion_dim_t(1)]); - assert(qSize == query.shape[legion_dim_t(0)]); - assert(num_samples == key.shape[legion_dim_t(2)]); - assert(kvSeqLength == key.shape[legion_dim_t(1)]); - assert(kSize == key.shape[legion_dim_t(0)]); - assert(num_samples == value.shape[legion_dim_t(2)]); - assert(kvSeqLength == value.shape[legion_dim_t(1)]); - assert(vSize == value.shape[legion_dim_t(0)]); - assert(num_samples == output.shape[legion_dim_t(2)]); - assert(qoSeqLength == output.shape[legion_dim_t(1)]); - assert(oProjSize == output.shape[legion_dim_t(0)]); - DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, @@ 
-149,9 +137,6 @@ static DeviceSpecific qoSeqLength, kvSeqLength, attrs.add_bias_kv)); - - assert(weight.shape.get_volume() * sizeof(float) == - acc.unwrap(per_device_state)->weightSize); return per_device_state; } diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index a8c2ec7bd7..00699652e7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -110,8 +110,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { int k = a_input.shape[legion_dim_t(0)]; assert(k == b_input.shape[legion_dim_t(1)]); - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; for (int i = 2; i < a_input.shape.get_dim(); @@ -171,8 +171,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(n == output.shape[legion_dim_t(1)]); int k = a_input.shape[legion_dim_t(0)]; assert(k == b_input.shape[legion_dim_t(1)]); - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; for (int i = 2; i < a_input.shape.dims.num_dims(); i++) { @@ -182,11 +182,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { batch *= dim_size; } - // TODO: add support for meta->a_seq_length_dim >= 0 - // or meta->b_seq_length_dim >= 0 - assert((meta->a_seq_length_dim >= a_len) || (iter_config.seq_length == 0)); - assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); - return profile(backward_kernel, profiling, "[BatchMatmul] backward_time = %.2lfms\n", From 09acbe545ad4d6d1cab1e14e55bf1d7e9a638954 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 
17:59:10 -0700 Subject: [PATCH 17/33] format --- lib/runtime/src/ops/batch_matmul.cc | 5 ++--- lib/runtime/src/task_spec/op_arg_ref.h | 5 +---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 00699652e7..45c5e11b9c 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -80,7 +80,7 @@ static DeviceSpecific DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); - + return per_device_state; } @@ -174,8 +174,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.get_volume() == b_input.shape.get_volume()); assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; - for (int i = 2; i < a_input.shape.dims.num_dims(); - i++) { + for (int i = 2; i < a_input.shape.dims.num_dims(); i++) { int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index 20ce6892a2..3e931d79a4 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -7,10 +7,7 @@ namespace FlexFlow { -enum class OpArgRefType { - PER_DEVICE_OP_STATE, - PARALLEL_TENSOR_SHAPE - }; +enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE }; template using OpArgRef = ArgRef; From c5be7a72048b563a72377b6806b7caf00d4f00d5 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Mon, 11 Sep 2023 12:19:01 -0700 Subject: [PATCH 18/33] batch_matmul --- .../include/kernels/batch_matmul_kernels.h | 8 +- lib/kernels/src/cuda/batch_matmul_kernels.cu | 177 ++------ lib/runtime/src/ops/batch_matmul.h | 420 +----------------- 3 files changed, 43 insertions(+), 562 deletions(-) diff --git 
a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index ec32648d0f..6c7fdf8d8f 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -3,6 +3,8 @@ #include "kernels/device.h" #include "kernels/ff_handle.h" +#include "kernels/allocation.h" +#include "utils/visitable.h" namespace FlexFlow { @@ -13,7 +15,7 @@ struct BMMPerDeviceState { req b_seq_length_dim; }; -FF_VISITABLE_STRUCT_NO_EQ( +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION( BMMPerDeviceState, handle, allocator, a_seq_length_dim, b_seq_length_dim); namespace Kernels { @@ -27,8 +29,8 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, float *output_ptr, - float const *lhs_input_ptr, - float const *rhs_input_ptr, + float const *a_input_ptr, + float const *b_input_ptr, int m, int n, int k, diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index cde0df93c0..000b3d307b 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -14,176 +14,81 @@ */ #include "kernels/batch_matmul_kernels.h" -#include "kernels/cuda_helper.h" +#include "kernels/device.h" +#include "device.h" namespace FlexFlow { - namespace Kernels { namespace BatchMatmul { -/* void forward_kernel_wrapper(BatchMatmulPerDeviceState const *meta, */ -/* float *o_ptr, */ -/* float const *a_ptr, */ -/* float const *b_ptr, */ -/* float const *c_ptr, */ -/* int m, */ -/* int n, */ -/* int k, */ -/* int batch, */ -/* int a_seq_length_dim, */ -/* int b_seq_length_dim, */ -/* int seq_length) { */ -/* cudaStream_t stream; */ -/* */ - -/* cudaEvent_t t_start, t_end; */ -/* if (meta->profiling) { */ -/* cudaEventCreate(&t_start); */ -/* cudaEventCreate(&t_end); */ -/* cudaEventRecord(t_start, stream); */ -/* } */ -/* Internal::forward_kernel(meta, */ 
-/* o_ptr, */ -/* a_ptr, */ -/* b_ptr, */ -/* c_ptr, */ -/* m, */ -/* n, */ -/* k, */ -/* batch, */ -/* stream, */ -/* a_seq_length_dim, */ -/* b_seq_length_dim, */ -/* seq_length); */ -/* if (meta->profiling) { */ -/* cudaEventRecord(t_end, stream); */ -/* checkCUDA(cudaEventSynchronize(t_end)); */ -/* float elapsed = 0; */ -/* checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); */ -/* cudaEventDestroy(t_start); */ -/* cudaEventDestroy(t_end); */ -/* printf("BatchMatmul forward time = %.2lfms\n", elapsed); */ -/* } */ -/* } */ - -/* void backward_kernel_wrapper(BatchMatmulPerDeviceState const *meta, */ -/* float const *o_ptr, */ -/* float const *o_grad_ptr, */ -/* float const *a_ptr, */ -/* float *a_grad_ptr, */ -/* float const *b_ptr, */ -/* float *b_grad_ptr, */ -/* float *c_grad_ptr, */ -/* int m, */ -/* int n, */ -/* int k, */ -/* int batch) { */ -/* cudaStream_t stream; */ -/* */ - -/* cudaEvent_t t_start, t_end; */ -/* if (meta->profiling) { */ -/* cudaEventCreate(&t_start); */ -/* cudaEventCreate(&t_end); */ -/* cudaEventRecord(t_start, stream); */ -/* } */ -/* Internal::backward_kernel(meta, */ -/* o_ptr, */ -/* o_grad_ptr, */ -/* a_ptr, */ -/* a_grad_ptr, */ -/* b_ptr, */ -/* b_grad_ptr, */ -/* c_grad_ptr, */ -/* m, */ -/* n, */ -/* k, */ -/* batch, */ -/* stream); */ -/* if (meta->profiling) { */ -/* cudaEventRecord(t_end, stream); */ -/* checkCUDA(cudaEventSynchronize(t_end)); */ -/* float elapsed = 0; */ -/* checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); */ -/* cudaEventDestroy(t_start); */ -/* cudaEventDestroy(t_end); */ -/* printf("BatchMatmul backward time = %.2lfms\n", elapsed); */ -/* } */ -/* } */ - -/* namespace Internal { */ +BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator const &allocator, + int a_seq_length_dim, + int b_seq_length_dim) { -/* -A: (batch, n, k) -B: (batch, k, m) -O: (batch, n, m) -O = A * B -*/ + BMMPerDeviceState per_device_state = {handle, + allocator, + a_seq_length_dim, + 
b_seq_length_dim}; + return per_device_state; +} -void forward_kernel(cudaStream_t stream, - BatchMatmulPerDeviceState const &meta, - float *o_ptr, - float const *a_ptr, - float const *b_ptr, - float const *c_ptr, +void forward_kernel(ffStream_t stream, + BMMPerDeviceState const &meta, + float *output_ptr, + float const *a_input_ptr, + float const *b_input_ptr, int m, int n, int k, int batch, - cudaStream_t stream, - int a_seq_length_dim, - int b_seq_length_dim, int seq_length) { - checkCUDA(cublasSetStream(meta->handle.blas, stream)); - checkCUDNN(cudnnSetStream(meta->handle.dnn, stream)); - - // int a_stride = n * k; - // int b_stride = m * k; - // int o_stride = n * m; + checkCUDA(cublasSetStream(meta.handle.blas, stream)); + checkCUDNN(cudnnSetStream(meta.handle.dnn, stream)); int lda = k; int ldb = m; int ldo = m; long long int strideA = (long long int)n * k; long long int strideB = (long long int)k * m; long long int strideO = (long long int)n * m; - if ((a_seq_length_dim == 0) && (seq_length >= 0)) { + if ((meta.a_seq_length_dim == 0) && (seq_length >= 0)) { assert(seq_length <= k); k = seq_length; - assert(b_seq_length_dim == 1); - } else if ((a_seq_length_dim == 1) && (seq_length >= 0)) { + assert(meta.b_seq_length_dim == 1); + } else if ((meta.a_seq_length_dim == 1) && (seq_length >= 0)) { assert(seq_length <= n); n = seq_length; } else { // currently only support a_seq_length_dim = 0 or 1 - assert((a_seq_length_dim < 0) || (seq_length < 0)); + assert((meta.a_seq_length_dim < 0) || (seq_length < 0)); } - if ((b_seq_length_dim == 0) && (seq_length >= 0)) { + if ((meta.b_seq_length_dim == 0) && (seq_length >= 0)) { assert(seq_length <= m); m = seq_length; - } else if ((b_seq_length_dim == 1) && (seq_length >= 0)) { - assert(a_seq_length_dim == 0); + } else if ((meta.b_seq_length_dim == 1) && (seq_length >= 0)) { + assert(meta.a_seq_length_dim == 0); assert(k == seq_length); } else { // currently only support a_seq_length_dim = 0 or 1 - 
assert((b_seq_length_dim < 0) || (seq_length < 0)); + assert((meta.b_seq_length_dim < 0) || (seq_length < 0)); } float alpha = 1.0f, beta = 0.0f; - checkCUDA(cublasSgemmStridedBatched(meta->handle.blas, + checkCUDA(cublasSgemmStridedBatched(meta.handle.blas, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, - b_ptr, + b_input_ptr, ldb, strideB, - a_ptr, + a_input_ptr, lda, strideA, &beta, - o_ptr, + output_ptr, ldo, strideO, batch)); @@ -191,34 +96,26 @@ void forward_kernel(cudaStream_t stream, assert(c_ptr == NULL); } -/* -A, AGrad: (batch, n, k) -B, BGrad: (batch, k, m) -O, OGrad: (batch, n, m) -AGrad = OGrad * B^T -BGrad = A^T * OGrad -*/ -void backward_kernel(cudaStream_t stream, - BatchMatmulPerDeviceState const *meta, +void backward_kernel(ffStream_t stream, + BMMPerDeviceState const &meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, float *a_grad_ptr, float const *b_ptr, float *b_grad_ptr, - float *c_grad_ptr, int m, int n, int k, int batch) { - checkCUDA(cublasSetStream(meta->handle.blas, stream)); - checkCUDNN(cudnnSetStream(meta->handle.dnn, stream)); + checkCUDA(cublasSetStream(meta.handle.blas, stream)); + checkCUDNN(cudnnSetStream(meta.handle.dnn, stream)); int a_stride = n * k; int b_stride = m * k; int o_stride = n * m; float alpha = 1.0f; - checkCUDA(cublasSgemmStridedBatched(meta->handle.blas, + checkCUDA(cublasSgemmStridedBatched(meta.handle.blas, CUBLAS_OP_T, CUBLAS_OP_N, k, @@ -236,7 +133,7 @@ void backward_kernel(cudaStream_t stream, k, a_stride, batch)); - checkCUDA(cublasSgemmStridedBatched(meta->handle.blas, + checkCUDA(cublasSgemmStridedBatched(meta.handle.blas, CUBLAS_OP_N, CUBLAS_OP_T, m, diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index 018fe1d582..1701ab493c 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -1,11 +1,6 @@ #ifndef _FLEXFLOW_BATCH_MATMUL_H #define _FLEXFLOW_BATCH_MATMUL_H -// #include "op-attrs/ops/batch_matmul.h" -// #include 
"task_spec/op_task_invocation.h" -// #include "task_spec/op_task_signature.h" -// #include "sim_environment.h" - #include "op-attrs/ops/batch_matmul.h" #include "sim_environment.h" #include "task_spec/op_task_invocation.h" @@ -32,417 +27,4 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, } // namespace FlexFlow -#endif - -// BatchMatmulParams BatchMatmul::get_params() const { -// BatchMatmulParams params; -// params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; -// params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; -// return params; -// } - -// Tensor FFModel::batch_matmul(const Tensor A, -// const Tensor B, -// int a_seq_length_dim, -// int b_seq_length_dim, -// char const *name) { -// Layer *bmm = new Layer(this, -// OP_BATCHMATMUL, -// DT_FLOAT, -// name, -// 2 /*inputs*/, -// 0 /*weights*/, -// 1 /*outputs*/, -// A, -// B); -// assert((a_seq_length_dim <= 1) && -// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " -// "Fortran ordering)."); -// assert((b_seq_length_dim <= 1) && -// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " -// "Fortran ordering)."); -// assert(A->num_dims == B->num_dims); -// for (int i = A->num_dims - 1; i >= 2; i--) { -// assert(A->dims[i] == B->dims[i]); -// } -// assert(A->dims[0] == B->dims[1]); -// int dims[MAX_TENSOR_DIM]; -// int numdim = A->num_dims; -// for (int i = 0; i < numdim; i++) { -// dims[i] = A->dims[i]; -// } -// dims[0] = B->dims[0]; -// bmm->outputs[0] = create_tensor_legion_ordering( -// numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); -// bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); -// bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); -// layers.push_back(bmm); -// return bmm->outputs[0]; -// } - -// Op *BatchMatmul::create_operator_from_layer( -// FFModel &model, -// Layer const *layer, -// std::vector const &inputs) { -// long long value; -// 
layer->get_int_property("a_seq_length_dim", value); -// int a_seq_length_dim = value; -// layer->get_int_property("b_seq_length_dim", value); -// int b_seq_length_dim = value; -// return new BatchMatmul(model, -// inputs[0], -// inputs[1], -// a_seq_length_dim, -// b_seq_length_dim, -// layer->name); -// } - -// BatchMatmul::BatchMatmul( -// FFModel &model, -// BatchMatmulParams const ¶ms, -// std::pair const &inputs, -// char const *name) -// : BatchMatmul(model, -// inputs.first, -// inputs.second, -// params.a_seq_length_dim, -// params.b_seq_length_dim, -// name) {} - -// // return A*B -// BatchMatmul::BatchMatmul(FFModel &model, -// const ParallelTensor A, -// const ParallelTensor B, -// int _a_seq_length_dim, -// int _b_seq_length_dim, -// char const *name) -// : Op(model, -// OP_BATCHMATMUL, -// DT_FLOAT, -// name, -// 2 /*inputs*/, -// 0 /*weights*/, -// 1 /*outputs*/, -// A, -// B), -// a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), -// b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { -// assert((_a_seq_length_dim <= 1) && -// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " -// "Fortran ordering)."); -// assert((_b_seq_length_dim <= 1) && -// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " -// "Fortran ordering)."); -// assert(A->num_dims == B->num_dims); -// for (int i = A->num_dims - 1; i >= 2; i--) { -// assert(A->dims[i] == B->dims[i]); -// } -// assert(A->dims[0] == B->dims[1]); -// ParallelDim dims[MAX_TENSOR_DIM]; -// for (int i = 0; i < A->num_dims; i++) { -// dims[i] = A->dims[i]; -// } -// dims[0] = B->dims[0]; -// numOutputs = 1; -// outputs[0] = model.create_parallel_tensor_legion_ordering( -// A->num_dims, dims, DT_FLOAT, this); -// // C is not none -// // if (C != Tensor::NO_TENSOR) { -// // numInputs = 3; -// // assert(C.num_dims == outputs[0].num_dims); -// // for (int i = 0; i < C.num_dims; i++) -// // assert(C.adim[i] == outputs[0].adim[i]); -// //} -// } - -// void 
BatchMatmul::serialize(Legion::Serializer &sez) const { -// BatchMatmulParams params = get_params(); -// sez.serialize(params.a_seq_length_dim); -// sez.serialize(params.b_seq_length_dim); -// } - -// using PCG::Node; -// /*static*/ -// Node BatchMatmul::deserialize(FFModel &ff, -// Legion::Deserializer &dez, -// ParallelTensor inputs[], -// int num_inputs) { -// assert(num_inputs == 2); -// int a_seq_length_dim, b_seq_length_dim; -// dez.deserialize(a_seq_length_dim); -// dez.deserialize(b_seq_length_dim); - -// BatchMatmulParams params; -// params.a_seq_length_dim = a_seq_length_dim; -// params.b_seq_length_dim = b_seq_length_dim; -// return ff.get_or_create_node({inputs[0], inputs[1]}, params); -// } - -// Op *BatchMatmul::materialize(FFModel &ff, -// ParallelTensor inputs[], -// int num_inputs) const { -// BatchMatmulParams params = get_params(); -// return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); -// } - -// void BatchMatmul::forward(FFModel const &ff) { -// int dim = outputs[0]->num_dims; -// switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ -// // forward_with_dim(ff); -// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, -// get_fwd_task_signature()); break; -// } -// LEGION_FOREACH_N(DIMFUNC) -// #undef DIMFUNC -// default: -// assert(false); -// } -// } - -// template -// void BatchMatmul::forward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_forward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_FWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 
0; i < numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](O): output - regions[1](I): A - regions[2](I): B - ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? - output = A * B /////////+ C -*/ - -// void BatchMatmul::init(FFModel const &ff) { -// int dim = outputs[0]->num_dims; -// switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ -// // init_with_dim(ff); -// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, -// get_init_task_signature()); break; -// } -// LEGION_FOREACH_N(DIMFUNC) -// #undef DIMFUNC -// default: -// assert(false); -// } -// } // namespace FlexFlow -// // / -// // template -// // void BatchMatmul::init_with_dim(FFModel const &ff) { -// // assert(check_output_input_weight_same_parallel_is()); -// // parallel_is = outputs[0]->parallel_is; -// // ArgumentMap argmap; -// // Context ctx = ff.config.lg_ctx; -// // Runtime *runtime = ff.config.lg_hlr; -// // set_argumentmap_for_init(ff, argmap); -// // IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, -// // parallel_is, -// // TaskArgument(this, sizeof(BatchMatmul)), -// // argmap, -// // Predicate::TRUE_PRED, -// // false /*must*/, -// // 0 /*mapper_id*/, -// // outputs[0]->machine_view.hash()); -// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// // 0 /*projection id*/, -// // WRITE_ONLY, -// // EXCLUSIVE, -// // outputs[0]->region)); -// // launcher.add_field(0, FID_DATA); -// // for (int i = 0; i < numInputs; i++) { -// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// // 0 /*projection id*/, -// // READ_ONLY, -// // EXCLUSIVE, -// // inputs[i]->region)); -// // launcher.add_field(i + 1, FID_DATA); -// // } -// // FutureMap fm = runtime->execute_index_space(ctx, launcher); -// // 
fm.wait_all_results(); -// // set_opmeta_from_futuremap(ff, fm); -// // } - -// OpTaskBinding BatchMatmul::get_bwd_task_binding() const { -// OpTaskBinding binding; -// binding.bind(A_INPUT, input_tensor(0)); -// binding.bind(B_INPUT, input_tensor(1)); -// binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); -// binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); - -// binding.bind(OUTPUT, output_tensor(0)); -// binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); - -// binding.bind_arg(ATTRS, this->attrs); -// return binding; -// } - -// static OpTaskSignature get_fwd_task_signature() { -// OpTaskSignature fwd(OpTaskType::FWD); - -// fwd.add_input_slot(A_INPUT, READ_WRITE); -// fwd.add_input_slot(B_INPUT, READ_WRITE); -// fwd.add_output_slot(OUTPUT); - -// return fwd; -// } - -// static OpTaskSignature get_bwd_task_signature() { -// OpTaskSignature bwd(OpTaskType::BWD); - -// bwd.add_input_slot(A_INPUT); -// bwd.add_input_slot(B_INPUT); -// bwd.add_input_grad_slot(A_INPUT_GRAD); -// bwd.add_input_grad_slot(B_INPUT_GRAD); -// bwd.add_output_slot(OUTPUT); -// bwd.add_output_grad_slot(OUTPUT_GRAD); - -// return bwd; -// } - -// OpTaskBinding BatchMatmul::get_init_task_binding() const { -// OpTaskBinding binding; - -// binding.bind_arg(ATTRS, this->attrs); -// binding.bind_arg(PROFILING, this->profiling); - -// return binding; -// } - -// OpTaskBinding BatchMatmul::get_fwd_task_binding() const { -// OpTaskBinding binding; - -// binding.bind(A_INPUT, input_tensor(0)); -// binding.bind(B_INPUT, input_tensor(1)); -// binding.bind(OUTPUT, output_tensor(0)); - -// binding.bind_arg(ATTRS, this->attrs); -// return binding; -// } - -// void BatchMatmul::backward(FFModel const &ff) { -// int dim = outputs[0]->num_dims; -// switch (dim) { -// #d ef ine DIMFUNC(DIM) \ -// case DIM: { \ -// backward_with_dim(ff); \ -// break; \ -// } -// LEGION_FOREACH_N(DIMFUNC) -// #undef DIMFUNC -// default: -// assert(false); -// } -// } - -// void 
BatchMatmul::print_layer(FFModel const &ff) { -// return; -// } - -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -// template -// void BatchMatmul::backward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_backward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_BWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// // regions[0](I): output -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// // regions[1](I): output_grad -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region_grad)); -// launcher.add_field(1, FID_DATA); -// // regions[2](I): A -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[0]->region)); -// launcher.add_field(2, FID_DATA); -// // regions[3](I/O): A_grad -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[0]->region_grad)); -// launcher.add_field(3, FID_DATA); -// // regions[4](I): B -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[1]->region)); -// launcher.add_field(4, FID_DATA); -// // regions[5](I/O): B_grad -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// 
inputs[1]->region_grad)); -// launcher.add_field(5, FID_DATA); -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ +#endif \ No newline at end of file From 67cf7be6d1acc1b080ab50db694d6b27402a7336 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Mon, 11 Sep 2023 12:28:30 -0700 Subject: [PATCH 19/33] format --- lib/kernels/include/kernels/batch_matmul_kernels.h | 2 +- lib/kernels/src/cuda/batch_matmul_kernels.cu | 8 +++----- lib/runtime/src/ops/batch_matmul.h | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 6c7fdf8d8f..137e76705e 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_OPS_KERNELS_BATCH_MATMUL_KERNELS_H #define _FLEXFLOW_OPS_KERNELS_BATCH_MATMUL_KERNELS_H +#include "kernels/allocation.h" #include "kernels/device.h" #include "kernels/ff_handle.h" -#include "kernels/allocation.h" #include "utils/visitable.h" namespace FlexFlow { diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index 000b3d307b..ec21da914b 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -13,9 +13,9 @@ * limitations under the License. 
*/ +#include "device.h" #include "kernels/batch_matmul_kernels.h" #include "kernels/device.h" -#include "device.h" namespace FlexFlow { namespace Kernels { @@ -26,10 +26,8 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, int a_seq_length_dim, int b_seq_length_dim) { - BMMPerDeviceState per_device_state = {handle, - allocator, - a_seq_length_dim, - b_seq_length_dim}; + BMMPerDeviceState per_device_state = { + handle, allocator, a_seq_length_dim, b_seq_length_dim}; return per_device_state; } diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index 1701ab493c..f0be288d02 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -27,4 +27,4 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, } // namespace FlexFlow -#endif \ No newline at end of file +#endif From bbdfe9a9c284293c550f4f8649dd01c9777dba20 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 13 Sep 2023 11:10:14 -0700 Subject: [PATCH 20/33] add cuda --- lib/kernels/include/kernels/batch_matmul_kernels.h | 14 +++++--------- lib/op-attrs/src/batch_matmul.cc | 8 -------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 137e76705e..c92cada8c3 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -10,21 +10,15 @@ namespace FlexFlow { struct BMMPerDeviceState { PerDeviceFFHandle handle; - Allocator allocator; - int a_seq_length_dim; - req b_seq_length_dim; }; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION( - BMMPerDeviceState, handle, allocator, a_seq_length_dim, b_seq_length_dim); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(BMMPerDeviceState, handle); namespace Kernels { namespace BatchMatmul { BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, - Allocator const &allocator, - int a_seq_length_dim, - int 
b_seq_length_dim); + Allocator const &allocator); void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, @@ -35,7 +29,9 @@ void forward_kernel(ffStream_t stream, int n, int k, int batch, - int seq_length = -1); + int seq_length, + int a_seq_length_dim, + int b_seq_length_dim); void backward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, diff --git a/lib/op-attrs/src/batch_matmul.cc b/lib/op-attrs/src/batch_matmul.cc index bd61c24737..1cc8c5cfda 100644 --- a/lib/op-attrs/src/batch_matmul.cc +++ b/lib/op-attrs/src/batch_matmul.cc @@ -2,14 +2,6 @@ namespace FlexFlow { -int get_aSeqLengthDim(BatchMatmulAttrs const &attrs) { - return attrs.a_seq_length_dim; -} - -int get_bSeqLengthDim(BatchMatmulAttrs const &attrs) { - return attrs.b_seq_length_dim; -} - /* bool BatchMatmulAttrs::is_valid( */ /* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { */ From dce19e23d3db07c28c76eae1de9f67947ba54576 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 13 Sep 2023 14:26:11 -0700 Subject: [PATCH 21/33] fix attention kuda --- lib/kernels/src/cuda/attention_kernels.cu | 269 ++++++------------ lib/kernels/src/cuda/batch_matmul_kernels.cu | 27 +- lib/op-attrs/include/op-attrs/ops/attention.h | 2 +- lib/op-attrs/include/op-attrs/parallel_dim.h | 2 +- lib/runtime/src/ops/attention.h | 6 +- lib/runtime/src/ops/batch_matmul.cc | 24 +- 6 files changed, 120 insertions(+), 210 deletions(-) diff --git a/lib/kernels/src/cuda/attention_kernels.cu b/lib/kernels/src/cuda/attention_kernels.cu index 5981179395..3e07b37319 100644 --- a/lib/kernels/src/cuda/attention_kernels.cu +++ b/lib/kernels/src/cuda/attention_kernels.cu @@ -14,52 +14,63 @@ */ #include "kernels/attention_kernels.h" -#include "kernels/cuda_helper.h" +#include "device.h" +#include "kernels/device.h" namespace FlexFlow { namespace Kernels { namespace MultiHeadAttention { -void init_kernel(MHAPerDeviceState *m, - int num_samples, - int num_heads, - int qSize, - int kSize, - int vSize, 
- int qProjSize, - int kProjSize, - int vProjSize, - int oProjSize, - int qoSeqLength, - int kvSeqLength, - bool add_bias_kv) { +MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator allocator, + int num_samples, + int num_heads, + int qSize, + int kSize, + int vSize, + int qProjSize, + int kProjSize, + int vProjSize, + int oProjSize, + int qoSeqLength, + int kvSeqLength, + bool add_bias_kv) { cudaStream_t stream; + ffAttnDescriptor_t attnDesc; + ffSeqDataDescriptor_t qDesc; + ffSeqDataDescriptor_t kDesc; + ffSeqDataDescriptor_t vDesc; + ffSeqDataDescriptor_t oDesc; + void *reserveSpace; + void *dropoutStates; + int *devQoSeqArray; + int *devKvSeqArray; + size_t reserveSpaceSize; + size_t dropoutStateSize; + size_t weightSize; + checkCUDA(get_legion_stream(&stream)); - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - checkCUDNN(cudnnCreateAttnDescriptor(&m->attnDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&m->qDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&m->kDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&m->vDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&m->oDesc)); + checkCUDNN(cudnnSetStream(handle.dnn, stream)); + checkCUDNN(cudnnCreateAttnDescriptor(&attnDesc)); + checkCUDNN(cudnnCreateSeqDataDescriptor(&qDesc)); + checkCUDNN(cudnnCreateSeqDataDescriptor(&kDesc)); + checkCUDNN(cudnnCreateSeqDataDescriptor(&vDesc)); + checkCUDNN(cudnnCreateSeqDataDescriptor(&oDesc)); + // Currently do not support adding bias to key/value projection assert(!add_bias_kv); cudnnAttnQueryMap_t attnMode = CUDNN_ATTN_QUERYMAP_ALL_TO_ONE; + // Assume no beam search for now int maxBeamSize = 1; - // printf("batchSize(%d) qSize(%d) kSize(%d) vSize(%d) qProjSize(%d) - // kProjSize(%d)\n", - // num_samples, attn->qSize, attn->kSize, attn->vSize, attn->qProjSize, - // attn->kProjSize); - // printf("vProjSize(%d) oProjSize(%d) qoSeqLength(%d) kvSeqLength(%d)\n", - // attn->vProjSize, attn->oProjSize, attn->qoSeqLength, - // attn->kvSeqLength); + 
cudnnMathType_t math_type; - if (m->handle.allowTensorOpMathConversion) { + if (handle.allowTensorOpMathConversion) { math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; } else { math_type = CUDNN_TENSOR_OP_MATH; } - checkCUDNN(cudnnSetAttnDescriptor(m->attnDesc, + checkCUDNN(cudnnSetAttnDescriptor(attnDesc, attnMode, num_heads, 1.0f /*smScalar*/, @@ -80,14 +91,13 @@ void init_kernel(MHAPerDeviceState *m, num_samples, maxBeamSize)); size_t workSpaceSize; - checkCUDNN(cudnnGetMultiHeadAttnBuffers(m->handle.dnn, - m->attnDesc, - &m->weightSize, + checkCUDNN(cudnnGetMultiHeadAttnBuffers(handle.dnn, + attnDesc, + &weightSize, &workSpaceSize, - &m->reserveSpaceSize)); - assert(workSpaceSize <= m->handle.workSpaceSize); - // printf("weightSize(%zu) workSpaceSize(%zu) reserveSpaceSize(%zu)\n", - // weightSize, workSpaceSize, reserveSpaceSize); + &reserveSpaceSize)); + assert(workSpaceSize <= handle.workSpaceSize); + int dimA[CUDNN_SEQDATA_DIM_COUNT]; cudnnSeqDataAxis_t axes[CUDNN_SEQDATA_DIM_COUNT]; assert(CUDNN_SEQDATA_DIM_COUNT == 4); @@ -107,7 +117,7 @@ void init_kernel(MHAPerDeviceState *m, dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; dimA[CUDNN_SEQDATA_TIME_DIM] = qoSeqLength; dimA[CUDNN_SEQDATA_VECT_DIM] = qSize; - checkCUDNN(cudnnSetSeqDataDescriptor(m->qDesc, + checkCUDNN(cudnnSetSeqDataDescriptor(qDesc, CUDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, @@ -122,7 +132,7 @@ void init_kernel(MHAPerDeviceState *m, dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; dimA[CUDNN_SEQDATA_TIME_DIM] = kvSeqLength; dimA[CUDNN_SEQDATA_VECT_DIM] = kSize; - checkCUDNN(cudnnSetSeqDataDescriptor(m->kDesc, + checkCUDNN(cudnnSetSeqDataDescriptor(kDesc, CUDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, @@ -137,7 +147,7 @@ void init_kernel(MHAPerDeviceState *m, dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; dimA[CUDNN_SEQDATA_TIME_DIM] = kvSeqLength; dimA[CUDNN_SEQDATA_VECT_DIM] = vSize; - checkCUDNN(cudnnSetSeqDataDescriptor(m->vDesc, + checkCUDNN(cudnnSetSeqDataDescriptor(vDesc, 
CUDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, @@ -152,7 +162,7 @@ void init_kernel(MHAPerDeviceState *m, dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; dimA[CUDNN_SEQDATA_TIME_DIM] = qoSeqLength; dimA[CUDNN_SEQDATA_VECT_DIM] = oProjSize; - checkCUDNN(cudnnSetSeqDataDescriptor(m->oDesc, + checkCUDNN(cudnnSetSeqDataDescriptor(oDesc, CUDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, @@ -163,109 +173,46 @@ void init_kernel(MHAPerDeviceState *m, } // allocate memory for the seqArray and reserve space { - size_t totalSize = m->reserveSpaceSize + sizeof(int) * num_samples * 2; + size_t totalSize = reserveSpaceSize + sizeof(int) * num_samples * 2; - m->devQoSeqArray = (int *)m->gpu_alloc(totalSize); - checkCUDA(cudaMemcpy(m->devQoSeqArray, + devQoSeqArray = (int *)allocator.allocate(totalSize); + checkCUDA(cudaMemcpy(devQoSeqArray, qoSeqArray, sizeof(int) * num_samples, cudaMemcpyHostToDevice)); - m->devKvSeqArray = m->devQoSeqArray + num_samples; - checkCUDA(cudaMemcpy(m->devKvSeqArray, + devKvSeqArray = devQoSeqArray + num_samples; + checkCUDA(cudaMemcpy(devKvSeqArray, kvSeqArray, sizeof(int) * num_samples, cudaMemcpyHostToDevice)); - m->reserveSpace = m->devKvSeqArray + num_samples; + reserveSpace = devKvSeqArray + num_samples; } // allocate memory for loWinIdx/hiWinIdx - m->loWinIdx = (int *)malloc(sizeof(int) * qoSeqLength); - m->hiWinIdx = (int *)malloc(sizeof(int) * qoSeqLength); + int *loWinIdx = (int *)malloc(sizeof(int) * qoSeqLength); + int *hiWinIdx = (int *)malloc(sizeof(int) * qoSeqLength); for (int i = 0; i < qoSeqLength; i++) { - m->loWinIdx[i] = 0; - m->hiWinIdx[i] = kvSeqLength; + loWinIdx[i] = 0; + hiWinIdx[i] = kvSeqLength; } + + MHAPerDeviceState per_device_state = {handle, + weightSize, + reserveSpaceSize, + attnDesc, + qDesc, + kDesc, + vDesc, + oDesc, + devQoSeqArray, + devKvSeqArray, + loWinIdx, + hiWinIdx, + reserveSpace, + allocator}; free(qoSeqArray); free(kvSeqArray); } -/* void forward_kernel_wrapper(MHAPerDeviceState const *m, */ -/* float 
const *query_ptr, */ -/* float const *key_ptr, */ -/* float const *value_ptr, */ -/* float const *weight_ptr, */ -/* float *output_ptr) { */ -/* wrapper(Internal::forward_kernel, m->profiling, ) */ -/* cudaStream_t stream; */ -/* checkCUDA(get_legion_stream(&stream)); */ - -/* cudaEvent_t t_start, t_end; */ -/* if (m->profiling) { */ -/* cudaEventCreate(&t_start); */ -/* cudaEventCreate(&t_end); */ -/* cudaEventRecord(t_start, stream); */ -/* } */ -/* Internal::forward_kernel( */ -/* m, query_ptr, key_ptr, value_ptr, weight_ptr, output_ptr, stream); */ -/* if (m->profiling) { */ -/* cudaEventRecord(t_end, stream); */ -/* checkCUDA(cudaEventSynchronize(t_end)); */ -/* float elapsed = 0; */ -/* checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); */ -/* cudaEventDestroy(t_start); */ -/* cudaEventDestroy(t_end); */ -/* printf("MultiHeadAttention forward time = %.2fms\n", elapsed); */ -/* // print_tensor<3, float>(acc_query.ptr, acc_query.rect, */ -/* // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, - */ -/* // acc_output.rect, "[Attention:forward:output]"); */ -/* } */ -/* } */ - -/* void backward_kernel_wrapper( */ -/* MHAPerDeviceState const *m, */ -/* float const *query_ptr, */ -/* float *query_grad_ptr, */ -/* float const *key_ptr, */ -/* float *key_grad_ptr, */ -/* float const *value_ptr, */ -/* float *value_grad_ptr, */ -/* float const *weight_ptr, */ -/* float *weight_grad_ptr, */ -/* float const *output_grad_ptr) { */ -/* cudaStream_t stream; */ -/* checkCUDA(get_legion_stream(&stream)); */ - -/* cudaEvent_t t_start, t_end; */ -/* if (m->profiling) { */ -/* cudaEventCreate(&t_start); */ -/* cudaEventCreate(&t_end); */ -/* cudaEventRecord(t_start, stream); */ -/* } */ - -/* Internal::backward_kernel(m, */ -/* query_ptr, */ -/* query_grad_ptr, */ -/* key_ptr, */ -/* key_grad_ptr, */ -/* value_ptr, */ -/* value_grad_ptr, */ -/* weight_ptr, */ -/* weight_grad_ptr, */ -/* output_grad_ptr, */ -/* stream); */ -/* if (m->profiling) { */ 
-/* cudaEventRecord(t_end, stream); */ -/* checkCUDA(cudaEventSynchronize(t_end)); */ -/* float elapsed = 0; */ -/* checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); */ -/* cudaEventDestroy(t_start); */ -/* cudaEventDestroy(t_end); */ -/* printf("MultiHeadAttention backward time = %.2fms\n", elapsed); */ -/* } */ -/* } */ - -/* namespace Internal { */ - void forward_kernel(cudaStream_t stream, MHAPerDeviceState *m, float const *query_ptr, @@ -355,56 +302,22 @@ void backward_kernel(cudaStream_t stream, m->reserveSpace)); } -/* } // namespace Internal */ +void cleanup_kernel(int *loWinIdx, + int *hiWinIdx, + ffAttnDescriptor_t attnDesc, + ffSeqDataDescriptor_t qDesc, + ffSeqDataDescriptor_t kDesc, + ffSeqDataDescriptor_t vDesc, + ffSeqDataDescriptor_t oDesc) { + free(loWinIdx); + free(hiWinIdx); + checkCUDNN(cudnnDestroyAttnDescriptor(attnDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(qDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(kDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(vDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(oDesc)); +} + } // namespace MultiHeadAttention } // namespace Kernels - -MHAPerDeviceState::MHAPerDeviceState(FFHandler handler, - Memory gpu_mem, - int num_samples, - int num_heads, - int qSize, - int kSize, - int vSize, - int qProjSize, - int kProjSize, - int vProjSize, - int oProjSize, - int qoSeqLength, - int kvSeqLength, - bool add_bias_kv) - : PerDeviceOpState(handler) {} - -MHAPerDeviceState::MHAPerDeviceState(FFHandler handler, - std::unique_ptr allocator, - MultiHeadAttentionAttrs const &attrs, - ArrayShape const &query_shape, - ArrayShape const &key_shape, - ArrayShape const &value_shape) { - : MHAPerDeviceState(handler, - allocator, - query_shape[2], - attrs.num_heads, - query_shape[0], - key_shape[0], - value_shape[0], - qProjSize(attrs), - kProjSize(attrs), - vProjSize(attrs), - oProjSize(attrs), - query_shape[1], - key_shape[1], - attrs.add_bias_kv) -{ } - - MHAPerDeviceState::~MHAPerDeviceState(void) { - 
free(loWinIdx); - free(hiWinIdx); - checkCUDNN(cudnnDestroyAttnDescriptor(attnDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(qDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(kDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(vDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(oDesc)); - } - } // namespace FlexFlow diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index ec21da914b..3ea1fc6951 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -22,12 +22,9 @@ namespace Kernels { namespace BatchMatmul { BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, - Allocator const &allocator, - int a_seq_length_dim, - int b_seq_length_dim) { + Allocator const &allocator) { - BMMPerDeviceState per_device_state = { - handle, allocator, a_seq_length_dim, b_seq_length_dim}; + BMMPerDeviceState per_device_state = {handle}; return per_device_state; } @@ -40,7 +37,9 @@ void forward_kernel(ffStream_t stream, int n, int k, int batch, - int seq_length) { + int a_seq_length_dim, + int b_seq_length_dim, + int seq_length = -1) { checkCUDA(cublasSetStream(meta.handle.blas, stream)); checkCUDNN(cudnnSetStream(meta.handle.dnn, stream)); int lda = k; @@ -49,26 +48,26 @@ void forward_kernel(ffStream_t stream, long long int strideA = (long long int)n * k; long long int strideB = (long long int)k * m; long long int strideO = (long long int)n * m; - if ((meta.a_seq_length_dim == 0) && (seq_length >= 0)) { + if ((a_seq_length_dim == 0) && (seq_length >= 0)) { assert(seq_length <= k); k = seq_length; - assert(meta.b_seq_length_dim == 1); - } else if ((meta.a_seq_length_dim == 1) && (seq_length >= 0)) { + assert(b_seq_length_dim == 1); + } else if ((a_seq_length_dim == 1) && (seq_length >= 0)) { assert(seq_length <= n); n = seq_length; } else { // currently only support a_seq_length_dim = 0 or 1 - assert((meta.a_seq_length_dim < 0) || (seq_length < 0)); + 
assert((a_seq_length_dim < 0) || (seq_length < 0)); } - if ((meta.b_seq_length_dim == 0) && (seq_length >= 0)) { + if ((b_seq_length_dim == 0) && (seq_length >= 0)) { assert(seq_length <= m); m = seq_length; - } else if ((meta.b_seq_length_dim == 1) && (seq_length >= 0)) { - assert(meta.a_seq_length_dim == 0); + } else if ((b_seq_length_dim == 1) && (seq_length >= 0)) { + assert(a_seq_length_dim == 0); assert(k == seq_length); } else { // currently only support a_seq_length_dim = 0 or 1 - assert((meta.b_seq_length_dim < 0) || (seq_length < 0)); + assert((b_seq_length_dim < 0) || (seq_length < 0)); } float alpha = 1.0f, beta = 0.0f; diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index ec3e592607..3f972e745d 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -12,7 +12,7 @@ struct MultiHeadAttentionAttrs { req dropout; req bias, add_bias_kv, add_zero_attn; }; -FF_VISITABLE_STRUCT(MultiHeadAttentionAttrs, +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(MultiHeadAttentionAttrs, embed_dim, num_heads, kdim, diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.h b/lib/op-attrs/include/op-attrs/parallel_dim.h index b9df2d9037..ba4ad68b62 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.h @@ -11,7 +11,7 @@ struct ParallelDim { int degree; req is_replica_dim; }; -FF_VISITABLE_STRUCT(ParallelDim, size, degree, is_replica_dim); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ParallelDim, size, degree, is_replica_dim); bool is_valid(ParallelDim const &); bool is_replica_dim(ParallelDim const &); diff --git a/lib/runtime/src/ops/attention.h b/lib/runtime/src/ops/attention.h index f0a5e0abc3..09a4ef036f 100644 --- a/lib/runtime/src/ops/attention.h +++ b/lib/runtime/src/ops/attention.h @@ -20,9 +20,9 @@ OpTaskInvocation backward(MultiHeadAttentionAttrs const &); CostMetrics 
measure_operator_cost(SimEnvFactory const &sim, MultiHeadAttentionAttrs const &attrs, - ParallelTensorShape const &query_shape, - ParallelTensorShape const &key_shape, - ParallelTensorShape const &value_shape, + InputParallelTensorDesc const &query_shape, + InputParallelTensorDesc const &key_shape, + InputParallelTensorDesc const &value_shape, ProfilingSettings const &settings, MachineView const &mv); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 45c5e11b9c..1db24bad12 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -31,20 +31,16 @@ using Legion::Task; enum Slots { A_INPUT, // tensor B_INPUT, // tensor + ATTRS, OUTPUT, // tensor PROFILING, HANDLE, - A_SEQ_LENGTH_DIM, - B_SEQ_LENGTH_DIM, PER_DEVICE_STATE, ITERATION_CONFIG }; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { OpTaskBinding init; - - init.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); - init.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); init.bind_arg(HANDLE, ff_handle()); return {BATCHMATMUL_INIT_TASK_ID, init}; @@ -57,6 +53,8 @@ OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { fwd.bind(B_INPUT, input_tensor(1)); fwd.bind(OUTPUT, output_tensor(0)); + fwd.bind_arg(ATTRS, attrs); + fwd.bind_arg(PROFILING, profiling_settings()); fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); fwd.bind_arg(ITERATION_CONFIG, iteration_config()); @@ -72,14 +70,12 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { - int const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); - int const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = acc.get_allocator(); DeviceSpecific per_device_state = acc.create_device_specific( - init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); + init_kernel(handle, 
allocator)); return per_device_state; } @@ -97,6 +93,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { auto a_input = acc.get_tensor(A_INPUT); auto b_input = acc.get_tensor(B_INPUT); auto output = acc.get_tensor(OUTPUT); + auto attrs = acc.get_argument(ATTRS); ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); @@ -133,7 +130,9 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { n, k, batch, - iter_config.seq_length); + iter_config.seq_length, + attrs.a_seq_length_dim, + attrs.b_seq_length_dim); } static void forward_task(Task const *task, @@ -217,8 +216,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, get_output_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; - init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); - init_binding.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); init_binding.bind_arg(HANDLE, ff_handle()); auto init_accessor = @@ -230,6 +227,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, fwd_binding.bind(A_INPUT, a_input); fwd_binding.bind(B_INPUT, b_input); fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(ATTRS, attrs); fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); @@ -251,9 +249,8 @@ template <> void register_task() { OpTaskSignature init(OpTaskType::INIT); - init.add_arg_slot(A_SEQ_LENGTH_DIM); - init.add_arg_slot(B_SEQ_LENGTH_DIM); init.add_unchecked_arg_slot(HANDLE); + init.add_return_value(); register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); } @@ -265,6 +262,7 @@ void register_task() { fwd.add_input_slot(A_INPUT); fwd.add_input_slot(B_INPUT); fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(ATTRS); fwd.add_arg_slot(PROFILING); fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); From e4a796da27c4c7be948c75fdc015815f5c07361f Mon Sep 17 00:00:00 2001 From: Kate Unger Date: 
Wed, 13 Sep 2023 14:31:04 -0700 Subject: [PATCH 22/33] format --- lib/kernels/src/cuda/attention_kernels.cu | 9 +++------ lib/op-attrs/include/op-attrs/ops/attention.h | 16 ++++++++-------- lib/op-attrs/include/op-attrs/parallel_dim.h | 5 ++++- lib/runtime/src/ops/batch_matmul.cc | 2 +- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/lib/kernels/src/cuda/attention_kernels.cu b/lib/kernels/src/cuda/attention_kernels.cu index 3e07b37319..7490089dd0 100644 --- a/lib/kernels/src/cuda/attention_kernels.cu +++ b/lib/kernels/src/cuda/attention_kernels.cu @@ -13,8 +13,8 @@ * limitations under the License. */ -#include "kernels/attention_kernels.h" #include "device.h" +#include "kernels/attention_kernels.h" #include "kernels/device.h" namespace FlexFlow { @@ -91,11 +91,8 @@ MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, num_samples, maxBeamSize)); size_t workSpaceSize; - checkCUDNN(cudnnGetMultiHeadAttnBuffers(handle.dnn, - attnDesc, - &weightSize, - &workSpaceSize, - &reserveSpaceSize)); + checkCUDNN(cudnnGetMultiHeadAttnBuffers( + handle.dnn, attnDesc, &weightSize, &workSpaceSize, &reserveSpaceSize)); assert(workSpaceSize <= handle.workSpaceSize); int dimA[CUDNN_SEQDATA_DIM_COUNT]; diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 3f972e745d..0852a953f0 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -13,14 +13,14 @@ struct MultiHeadAttentionAttrs { req bias, add_bias_kv, add_zero_attn; }; FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(MultiHeadAttentionAttrs, - embed_dim, - num_heads, - kdim, - vdim, - dropout, - bias, - add_bias_kv, - add_zero_attn); + embed_dim, + num_heads, + kdim, + vdim, + dropout, + bias, + add_bias_kv, + add_zero_attn); template struct MultiHeadAttentionInputs diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.h b/lib/op-attrs/include/op-attrs/parallel_dim.h index ba4ad68b62..9d407ec469 
100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.h @@ -11,7 +11,10 @@ struct ParallelDim { int degree; req is_replica_dim; }; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ParallelDim, size, degree, is_replica_dim); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ParallelDim, + size, + degree, + is_replica_dim); bool is_valid(ParallelDim const &); bool is_replica_dim(ParallelDim const &); diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 1db24bad12..c1a4f05991 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -32,7 +32,7 @@ enum Slots { A_INPUT, // tensor B_INPUT, // tensor ATTRS, - OUTPUT, // tensor + OUTPUT, // tensor PROFILING, HANDLE, PER_DEVICE_STATE, From e817e75e87e72cb03678b6448f042155182dd466 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 29 Sep 2023 20:11:32 -0700 Subject: [PATCH 23/33] draft --- lib/op-attrs/include/op-attrs/ops/batch_matmul.h | 3 --- lib/runtime/src/ops/batch_matmul.cc | 1 - lib/runtime/src/task_spec/runtime_arg_ref.h | 2 +- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 00c700ba20..b05a5eb022 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -12,9 +12,6 @@ struct BatchMatmulAttrs { }; FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); -int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); -int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); - CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index c1a4f05991..879844d957 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -152,7 +152,6 @@ static optional 
backward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); - // is this equivalent to checking `Domain` equality? assert(output.shape == output_grad.shape); auto a_input = acc.get_tensor(A_INPUT); diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.h b/lib/runtime/src/task_spec/runtime_arg_ref.h index 033c2bcfbc..e85a191209 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.h +++ b/lib/runtime/src/task_spec/runtime_arg_ref.h @@ -7,7 +7,7 @@ namespace FlexFlow { -enum class RuntimeArgRefType { FF_HANDLE, PROFILING_SETTINGS }; +enum class RuntimeArgRefType { FF_HANDLE, PROFILING_SETTINGS, FF_ITERATION_CONFIG }; template using RuntimeArgRef = ArgRef; From 5c80bcdcccfba8bc502363b1c2f9cf3d6cd6c5e7 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 29 Sep 2023 20:15:27 -0700 Subject: [PATCH 24/33] format --- lib/runtime/src/task_spec/runtime_arg_ref.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.h b/lib/runtime/src/task_spec/runtime_arg_ref.h index e85a191209..655300e692 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.h +++ b/lib/runtime/src/task_spec/runtime_arg_ref.h @@ -7,7 +7,11 @@ namespace FlexFlow { -enum class RuntimeArgRefType { FF_HANDLE, PROFILING_SETTINGS, FF_ITERATION_CONFIG }; +enum class RuntimeArgRefType { + FF_HANDLE, + PROFILING_SETTINGS, + FF_ITERATION_CONFIG +}; template using RuntimeArgRef = ArgRef; From 70dc9f0183ef6d40f3c587f683dfe40c81a6c833 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 29 Sep 2023 22:00:43 -0700 Subject: [PATCH 25/33] reyna fixes --- lib/runtime/src/ops/batch_matmul.cc | 37 ++++++++++++++++++++++------- lib/runtime/src/ops/batch_matmul.h | 6 +++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 879844d957..a57de1aadf 100644 --- a/lib/runtime/src/ops/batch_matmul.cc 
+++ b/lib/runtime/src/ops/batch_matmul.cc @@ -244,18 +244,24 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, return make_metrics(forward_time, backward_time, sync_time, env); } -template <> -void register_task() { +OpTaskSignature init_signature() { OpTaskSignature init(OpTaskType::INIT); init.add_unchecked_arg_slot(HANDLE); init.add_return_value(); - register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); + return init; } template <> -void register_task() { +void register_task() { + register_task(BATCHMATMUL_INIT_TASK_ID, + "BatchMatmul Init", + init_signature(), + init_task); +} + +OpTaskSignature fwd_signature() { OpTaskSignature fwd(OpTaskType::FWD); fwd.add_input_slot(A_INPUT); @@ -265,15 +271,30 @@ void register_task() { fwd.add_arg_slot(PROFILING); fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - register_task(BATCHMATMUL_FWD_TASK_ID, "BatchMatmul Fwd", fwd, forward_task); + return fwd; } template <> -void register_task() { +void register_task() { + register_task(BATCHMATMUL_FWD_TASK_ID, + "BatchMatmul Fwd", + fwd_signature(), + forward_task); +} + +OpTaskSignature bwd_signature() { OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(ATTENTION_FWD_TASK_ID)); + infer_bwd_signature(get_op_signature(BATCHMATMUL_BWD_TASK_ID)); - register_task(BATCHMATMUL_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); + return bwd; +} + +template <> +void register_task() { + register_task(BATCHMATMUL_BWD_TASK_ID, + "BatchMatmul Bwd", + bwd_signature(), + backward_task); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index f0be288d02..a3e3fbbb91 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -4,13 +4,19 @@ #include "op-attrs/ops/batch_matmul.h" #include "sim_environment.h" #include "task_spec/op_task_invocation.h" +#include "task_spec/op_task_signature.h" namespace FlexFlow { +OpTaskSignature init_signature(); template <> 
void register_task(); + +OpTaskSignature fwd_signature(); template <> void register_task(); + +OpTaskSignature bwd_signature(); template <> void register_task(); From 03b291cf37d91bb08a18af1540b03ba4f2f3f1d8 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 29 Sep 2023 22:22:57 -0700 Subject: [PATCH 26/33] format --- lib/runtime/src/ops/batch_matmul.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index a57de1aadf..d99a56b779 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -255,9 +255,9 @@ OpTaskSignature init_signature() { template <> void register_task() { - register_task(BATCHMATMUL_INIT_TASK_ID, - "BatchMatmul Init", - init_signature(), + register_task(BATCHMATMUL_INIT_TASK_ID, + "BatchMatmul Init", + init_signature(), init_task); } @@ -276,9 +276,9 @@ OpTaskSignature fwd_signature() { template <> void register_task() { - register_task(BATCHMATMUL_FWD_TASK_ID, - "BatchMatmul Fwd", - fwd_signature(), + register_task(BATCHMATMUL_FWD_TASK_ID, + "BatchMatmul Fwd", + fwd_signature(), forward_task); } @@ -291,9 +291,9 @@ OpTaskSignature bwd_signature() { template <> void register_task() { - register_task(BATCHMATMUL_BWD_TASK_ID, - "BatchMatmul Bwd", - bwd_signature(), + register_task(BATCHMATMUL_BWD_TASK_ID, + "BatchMatmul Bwd", + bwd_signature(), backward_task); } From 2bb4501319a9426d504716a67fd14afaa877e910 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 3 Oct 2023 18:28:08 -0700 Subject: [PATCH 27/33] format repo-refactor --- .../substitutions/sub_parallel_computation_graph.h | 3 ++- lib/substitutions/src/sub_parallel_computation_graph.cc | 7 ++++--- .../utils/graph/labelled/labelled_open_interfaces.h | 9 +++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h 
b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 96a3b41dfc..352ffc1dec 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -14,7 +14,8 @@ using SubParallelComputationGraph = CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(SubParallelComputationGraph); -ParallelTensor at(SubParallelComputationGraph const &g, OpenMultiDiEdge const &e); +ParallelTensor at(SubParallelComputationGraph const &g, + OpenMultiDiEdge const &e); } // namespace FlexFlow diff --git a/lib/substitutions/src/sub_parallel_computation_graph.cc b/lib/substitutions/src/sub_parallel_computation_graph.cc index ac67451c78..e8ab70648f 100644 --- a/lib/substitutions/src/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/sub_parallel_computation_graph.cc @@ -2,8 +2,9 @@ namespace FlexFlow { -ParallelTensor at(SubParallelComputationGraph const &g, OpenMultiDiEdge const &e) { - return visit([&](const auto &e) { return g.at(e); }, e); +ParallelTensor at(SubParallelComputationGraph const &g, + OpenMultiDiEdge const &e) { + return visit([&](auto const &e) { return g.at(e); }, e); } -} +} // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h b/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h index 20138c4212..2db654c615 100644 --- a/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h +++ b/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_INTERFACES_H #include "standard_labelled_interfaces.h" -#include "utils/graph/open_graph_interfaces.h" #include "utils/containers.h" +#include "utils/graph/open_graph_interfaces.h" namespace FlexFlow { @@ -15,11 +15,12 @@ struct ILabelledOpenMultiDiGraphView : public IOpenMultiDiGraphView, public ILabelledMultiDiGraphView { public: - std::unordered_set 
query_edges(MultiDiEdgeQuery const &q) const final { - return map_over_unordered_set([](OpenMultiDiEdge const &e) { return get(e); }, - IOpenMultiDiGraphView::query_edges(static_cast(q))); + return map_over_unordered_set( + [](OpenMultiDiEdge const &e) { return get(e); }, + IOpenMultiDiGraphView::query_edges( + static_cast(q))); } using ILabelledMultiDiGraphView::at; From 125921bd175dc65e6f0cbfe41cde822326cadff7 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 3 Oct 2023 19:03:33 -0700 Subject: [PATCH 28/33] delete init and split register_task --- .../include/kernels/batch_matmul_kernels.h | 14 +--- lib/kernels/src/cuda/batch_matmul_kernels.cu | 25 +++---- lib/runtime/src/ops/batch_matmul.cc | 71 +++---------------- lib/runtime/src/ops/batch_matmul.h | 8 --- 4 files changed, 19 insertions(+), 99 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index c92cada8c3..c1966309a4 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -7,21 +7,11 @@ #include "utils/visitable.h" namespace FlexFlow { - -struct BMMPerDeviceState { - PerDeviceFFHandle handle; -}; - -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(BMMPerDeviceState, handle); - namespace Kernels { namespace BatchMatmul { -BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, - Allocator const &allocator); - void forward_kernel(ffStream_t stream, - BMMPerDeviceState const &meta, + PerDeviceFFHandle const &handle, float *output_ptr, float const *a_input_ptr, float const *b_input_ptr, @@ -34,7 +24,7 @@ void forward_kernel(ffStream_t stream, int b_seq_length_dim); void backward_kernel(ffStream_t stream, - BMMPerDeviceState const &meta, + PerDeviceFFHandle const &handle, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index 
3ea1fc6951..ad7470290d 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -21,15 +21,8 @@ namespace FlexFlow { namespace Kernels { namespace BatchMatmul { -BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, - Allocator const &allocator) { - - BMMPerDeviceState per_device_state = {handle}; - return per_device_state; -} - void forward_kernel(ffStream_t stream, - BMMPerDeviceState const &meta, + PerDeviceFFHandle const &handle, float *output_ptr, float const *a_input_ptr, float const *b_input_ptr, @@ -40,8 +33,8 @@ void forward_kernel(ffStream_t stream, int a_seq_length_dim, int b_seq_length_dim, int seq_length = -1) { - checkCUDA(cublasSetStream(meta.handle.blas, stream)); - checkCUDNN(cudnnSetStream(meta.handle.dnn, stream)); + checkCUDA(cublasSetStream(handle.blas, stream)); + checkCUDNN(cudnnSetStream(handle.dnn, stream)); int lda = k; int ldb = m; int ldo = m; @@ -71,7 +64,7 @@ void forward_kernel(ffStream_t stream, } float alpha = 1.0f, beta = 0.0f; - checkCUDA(cublasSgemmStridedBatched(meta.handle.blas, + checkCUDA(cublasSgemmStridedBatched(handle.blas, CUBLAS_OP_N, CUBLAS_OP_N, m, @@ -94,7 +87,7 @@ void forward_kernel(ffStream_t stream, } void backward_kernel(ffStream_t stream, - BMMPerDeviceState const &meta, + PerDeviceFFHandle const &handle, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, @@ -105,14 +98,14 @@ void backward_kernel(ffStream_t stream, int n, int k, int batch) { - checkCUDA(cublasSetStream(meta.handle.blas, stream)); - checkCUDNN(cudnnSetStream(meta.handle.dnn, stream)); + checkCUDA(cublasSetStream(handle.blas, stream)); + checkCUDNN(cudnnSetStream(handle.dnn, stream)); int a_stride = n * k; int b_stride = m * k; int o_stride = n * m; float alpha = 1.0f; - checkCUDA(cublasSgemmStridedBatched(meta.handle.blas, + checkCUDA(cublasSgemmStridedBatched(handle.blas, CUBLAS_OP_T, CUBLAS_OP_N, k, @@ -130,7 +123,7 @@ void backward_kernel(ffStream_t stream, k, 
a_stride, batch)); - checkCUDA(cublasSgemmStridedBatched(meta.handle.blas, + checkCUDA(cublasSgemmStridedBatched(handle.blas, CUBLAS_OP_N, CUBLAS_OP_T, m, diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index d99a56b779..e376e8592f 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -35,17 +35,9 @@ enum Slots { OUTPUT, // tensor PROFILING, HANDLE, - PER_DEVICE_STATE, ITERATION_CONFIG }; -OpTaskInvocation init(BatchMatmulAttrs const &attrs) { - OpTaskBinding init; - init.bind_arg(HANDLE, ff_handle()); - - return {BATCHMATMUL_INIT_TASK_ID, init}; -} - OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { OpTaskBinding fwd; @@ -54,9 +46,8 @@ OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { fwd.bind(OUTPUT, output_tensor(0)); fwd.bind_arg(ATTRS, attrs); - + fwd.bind_arg(HANDLE, ff_handle()); fwd.bind_arg(PROFILING, profiling_settings()); - fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); fwd.bind_arg(ITERATION_CONFIG, iteration_config()); return {BATCHMATMUL_FWD_TASK_ID, fwd}; @@ -68,35 +59,14 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static DeviceSpecific - init_task_impl(TaskArgumentAccessor const &acc) { - PerDeviceFFHandle handle = acc.get_argument(HANDLE); - Allocator allocator = acc.get_allocator(); - - DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(handle, allocator)); - - return per_device_state; -} - -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - static optional forward_task_impl(TaskArgumentAccessor const &acc) { auto a_input = acc.get_tensor(A_INPUT); auto b_input = acc.get_tensor(B_INPUT); auto output = acc.get_tensor(OUTPUT); auto attrs = acc.get_argument(ATTRS); + PerDeviceFFHandle handle = 
acc.get_argument(HANDLE); ProfilingSettings profiling = acc.get_argument(PROFILING); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); @@ -122,7 +92,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, "[BatchMatmul] forward_time = %.2lfms\n", - per_device_state, + handle, output.get_float_ptr(), a_input.get_float_ptr(), b_input.get_float_ptr(), @@ -148,7 +118,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); ProfilingSettings profiling = acc.get_argument(PROFILING); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); @@ -182,7 +152,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, "[BatchMatmul] backward_time = %.2lfms\n", - per_device_state, + handle, output.get_float_ptr(), output_grad.get_float_ptr(), a_input.get_float_ptr(), @@ -214,21 +184,13 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); - SimTaskBinding init_binding; - init_binding.bind_arg(HANDLE, ff_handle()); - - auto init_accessor = - env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, init_binding); - DeviceSpecific per_device_state = - init_task_impl(init_accessor); - SimTaskBinding fwd_binding; fwd_binding.bind(A_INPUT, a_input); fwd_binding.bind(B_INPUT, b_input); fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind_arg(ATTRS, attrs); fwd_binding.bind_arg(PROFILING, settings); - fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + fwd_binding.bind_arg(HANDLE, ff_handle()); SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); @@ 
-244,23 +206,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, return make_metrics(forward_time, backward_time, sync_time, env); } -OpTaskSignature init_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_unchecked_arg_slot(HANDLE); - init.add_return_value(); - - return init; -} - -template <> -void register_task() { - register_task(BATCHMATMUL_INIT_TASK_ID, - "BatchMatmul Init", - init_signature(), - init_task); -} - OpTaskSignature fwd_signature() { OpTaskSignature fwd(OpTaskType::FWD); @@ -269,7 +214,7 @@ OpTaskSignature fwd_signature() { fwd.add_output_slot(OUTPUT); fwd.add_arg_slot(ATTRS); fwd.add_arg_slot(PROFILING); - fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + fwd.add_unchecked_arg_slot(HANDLE); return fwd; } @@ -284,7 +229,7 @@ void register_task() { OpTaskSignature bwd_signature() { OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(BATCHMATMUL_BWD_TASK_ID)); + infer_bwd_signature(get_op_signature(BATCHMATMUL_FWD_TASK_ID)); return bwd; } diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index a3e3fbbb91..7d3f2308da 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -8,19 +8,11 @@ namespace FlexFlow { -OpTaskSignature init_signature(); -template <> -void register_task(); - -OpTaskSignature fwd_signature(); template <> void register_task(); - -OpTaskSignature bwd_signature(); template <> void register_task(); -OpTaskInvocation init(BatchMatmulAttrs const &); OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); From dc225a3346e9df55522b50f9d366e4d5d6a996e1 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 3 Oct 2023 21:55:56 -0700 Subject: [PATCH 29/33] finish bmm and att. 
--- lib/runtime/src/ops/attention.cc | 40 ++++++++++++++++++++++------- lib/runtime/src/ops/batch_matmul.cc | 11 +++++--- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/lib/runtime/src/ops/attention.cc b/lib/runtime/src/ops/attention.cc index 94e2b03731..41905f9014 100644 --- a/lib/runtime/src/ops/attention.cc +++ b/lib/runtime/src/ops/attention.cc @@ -17,6 +17,7 @@ #include "kernels/attention_kernels.h" #include "legion.h" #include "op-attrs/ops/attention.h" +#include "task_spec/op_task_signature.h" namespace FlexFlow { @@ -284,7 +285,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, } template <> -void register_task() { +OpTaskSignature init_signature() { OpTaskSignature init(OpTaskType::INIT); init.add_arg_slot(QUERY_PARALLEL_TENSOR_SHAPE); init.add_arg_slot(KEY_PARALLEL_TENSOR_SHAPE); @@ -298,12 +299,19 @@ void register_task() { init.add_return_value(); - register_task( - ATTENTION_INIT_TASK_ID, "MultiHeadAttention Init", init, init_task); + return init; } template <> -void register_task() { +void register_task() { + register_task(ATTENTION_INIT_TASK_ID, + "Attention Init", + init_signature(), + init_task); +} + +template <> +OpTaskSignature fwd_signature() { OpTaskSignature fwd(OpTaskType::FWD); fwd.add_input_slot(QUERY); @@ -315,17 +323,31 @@ void register_task() { fwd.add_arg_slot(PROFILING); fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - register_task( - ATTENTION_FWD_TASK_ID, "MultiHeadAttention Fwd", fwd, forward_task); + return fwd; } template <> -void register_task() { +void register_task() { + register_task(ATTENTION_FWD_TASK_ID, + "Attention Fwd", + fwd_signature(), + forward_task); +} + +template <> +OpTaskSignature bwd_signature() { OpTaskSignature bwd = infer_bwd_signature(get_op_signature(ATTENTION_FWD_TASK_ID)); - register_task( - ATTENTION_BWD_TASK_ID, "MultiHeadAttention Bwd", bwd, backward_task); + return bwd; +} + +template <> +void register_task() { + register_task(ATTENTION_BWD_TASK_ID, + "Attention Bwd", + 
bwd_signature(), + backward_task); } } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index e376e8592f..9be6df4d9e 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -18,6 +18,7 @@ #include "legion.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/batch_matmul.h" +#include "task_spec/op_task_signature.h" namespace FlexFlow { @@ -206,7 +207,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, return make_metrics(forward_time, backward_time, sync_time, env); } -OpTaskSignature fwd_signature() { +template <> +OpTaskSignature fwd_signature() { OpTaskSignature fwd(OpTaskType::FWD); fwd.add_input_slot(A_INPUT); @@ -223,11 +225,12 @@ template <> void register_task() { register_task(BATCHMATMUL_FWD_TASK_ID, "BatchMatmul Fwd", - fwd_signature(), + fwd_signature(), forward_task); } -OpTaskSignature bwd_signature() { +template <> +OpTaskSignature bwd_signature() { OpTaskSignature bwd = infer_bwd_signature(get_op_signature(BATCHMATMUL_FWD_TASK_ID)); @@ -238,7 +241,7 @@ template <> void register_task() { register_task(BATCHMATMUL_BWD_TASK_ID, "BatchMatmul Bwd", - bwd_signature(), + bwd_signature(), backward_task); } From d6ce742d8d58db9f6a1a0022aa64969fcfc587f4 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Fri, 6 Oct 2023 23:51:18 -0700 Subject: [PATCH 30/33] Align hip kernel --- lib/kernels/src/cuda/batch_matmul_kernels.cu | 8 ++--- lib/kernels/src/hip/batch_matmul_kernels.cpp | 38 ++++++++------------ lib/runtime/src/ops/batch_matmul.cc | 9 +++-- 3 files changed, 21 insertions(+), 34 deletions(-) diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index ad7470290d..08453b9c65 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -21,7 +21,7 @@ namespace FlexFlow { namespace Kernels { namespace BatchMatmul { -void 
forward_kernel(ffStream_t stream, +void forward_kernel(cudaStream_t stream, PerDeviceFFHandle const &handle, float *output_ptr, float const *a_input_ptr, @@ -82,11 +82,9 @@ void forward_kernel(ffStream_t stream, ldo, strideO, batch)); - // current assume c is null - assert(c_ptr == NULL); } -void backward_kernel(ffStream_t stream, +void backward_kernel(cudaStream_t stream, PerDeviceFFHandle const &handle, float const *o_ptr, float const *o_grad_ptr, @@ -141,10 +139,8 @@ void backward_kernel(ffStream_t stream, m, b_stride, batch)); - assert(c_grad_ptr == NULL); } -/* } // namespace Internal */ } // namespace BatchMatmul } // namespace Kernels } // namespace FlexFlow diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp b/lib/kernels/src/hip/batch_matmul_kernels.cpp index a06442d3d6..cbfd669e0f 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -18,7 +18,6 @@ #include namespace FlexFlow { - namespace Kernels { namespace BatchMatmul { @@ -29,21 +28,19 @@ O: (batch, n, m) O = A * B */ void forward_kernel(hipStream_t stream, - BatchMatmulPerDeviceState const &meta, - float *o_ptr, - float const *a_ptr, - float const *b_ptr, - float const *c_ptr, + PerDeviceFFHandle const &handle, + float *output_ptr, + float const *a_input_ptr, + float const *b_input_ptr, int m, int n, int k, int batch, - hipStream_t stream, - int seq_length) { - int a_seq_length_dim = meta->a_seq_length_dim; - int b_seq_length_dim = meta->b_seq_length_dim; - checkCUDA(hipblasSetStream(meta->handle.blas, stream)); - checkCUDNN(miopenSetStream(meta->handle.dnn, stream)); + int a_seq_length_dim, + int b_seq_length_dim, + int seq_length = -1) { + checkCUDA(hipblasSetStream(handle.blas, stream)); + checkCUDNN(miopenSetStream(handle.dnn, stream)); // int a_stride = n * k; // int b_stride = m * k; @@ -77,7 +74,7 @@ void forward_kernel(hipStream_t stream, } float alpha = 1.0f, beta = 0.0f; - checkCUDA(hipblasSgemmStridedBatched(meta->handle.blas, 
+ checkCUDA(hipblasSgemmStridedBatched(handle.blas, HIPBLAS_OP_N, HIPBLAS_OP_N, m, @@ -95,8 +92,6 @@ void forward_kernel(hipStream_t stream, ldo, strideO, batch)); - // current assume c is null - assert(c_ptr == NULL); } /* @@ -107,26 +102,25 @@ AGrad = OGrad * B^T BGrad = A^T * OGrad */ void backward_kernel(hipStream_t stream, - BatchMatmulPerDeviceState const *meta, + PerDeviceFFHandle const &handle, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, float *a_grad_ptr, float const *b_ptr, float *b_grad_ptr, - float *c_grad_ptr, int m, int n, int k, int batch) { - checkCUDA(hipblasSetStream(meta->handle.blas, stream)); - checkCUDNN(miopenSetStream(meta->handle.dnn, stream)); + checkCUDA(hipblasSetStream(handle.blas, stream)); + checkCUDNN(miopenSetStream(handle.dnn, stream)); int a_stride = n * k; int b_stride = m * k; int o_stride = n * m; float alpha = 1.0f; - checkCUDA(hipblasSgemmStridedBatched(meta->handle.blas, + checkCUDA(hipblasSgemmStridedBatched(handle.blas, HIPBLAS_OP_T, HIPBLAS_OP_N, k, @@ -144,7 +138,7 @@ void backward_kernel(hipStream_t stream, k, a_stride, batch)); - checkCUDA(hipblasSgemmStridedBatched(meta->handle.blas, + checkCUDA(hipblasSgemmStridedBatched(handle.blas, HIPBLAS_OP_N, HIPBLAS_OP_T, m, @@ -162,10 +156,8 @@ void backward_kernel(hipStream_t stream, m, b_stride, batch)); - assert(c_grad_ptr == NULL); } -} // namespace Internal } // namespace BatchMatmul } // namespace Kernels } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 9be6df4d9e..5f40def699 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -82,8 +82,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); - i++) { // get_dim() or get_volume()? 
+ for (int i = 2; i < a_input.shape.get_dim(); i++) { int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -101,9 +100,9 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { n, k, batch, - iter_config.seq_length, attrs.a_seq_length_dim, - attrs.b_seq_length_dim); + attrs.b_seq_length_dim, + iter_config.seq_length); } static void forward_task(Task const *task, @@ -232,7 +231,7 @@ void register_task() { template <> OpTaskSignature bwd_signature() { OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(BATCHMATMUL_FWD_TASK_ID)); + infer_bwd_signature(fwd_signature()); return bwd; } From 3f6b8b093167193a5b2c091420f3b26068cf6286 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Tue, 10 Oct 2023 10:28:47 -0700 Subject: [PATCH 31/33] Fix attention kernels --- .../include/kernels/attention_kernels.h | 5 +- lib/kernels/src/cuda/attention_kernels.cu | 136 +++++++++--------- lib/kernels/src/cuda/batch_matmul_kernels.cu | 2 +- lib/op-attrs/include/op-attrs/ops/attention.h | 18 +-- 4 files changed, 81 insertions(+), 80 deletions(-) diff --git a/lib/kernels/include/kernels/attention_kernels.h b/lib/kernels/include/kernels/attention_kernels.h index 7c530bc10c..d474ea4da5 100644 --- a/lib/kernels/include/kernels/attention_kernels.h +++ b/lib/kernels/include/kernels/attention_kernels.h @@ -46,7 +46,7 @@ namespace Kernels { namespace MultiHeadAttention { MHAPerDeviceState init_kernel(PerDeviceFFHandle const &, - Allocator const &, + Allocator &, int num_samples, int num_heads, int qSize, @@ -80,6 +80,9 @@ void backward_kernel(ffStream_t stream, float *weight_grad_ptr, float const *output_grad_ptr); +void cleanup_kernel(Allocator &allocator, + MHAPerDeviceState const &device_state); + } // namespace MultiHeadAttention } // namespace Kernels } // namespace FlexFlow diff --git a/lib/kernels/src/cuda/attention_kernels.cu 
b/lib/kernels/src/cuda/attention_kernels.cu index 7490089dd0..5e11688a46 100644 --- a/lib/kernels/src/cuda/attention_kernels.cu +++ b/lib/kernels/src/cuda/attention_kernels.cu @@ -22,7 +22,7 @@ namespace Kernels { namespace MultiHeadAttention { MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, - Allocator allocator, + Allocator &allocator, int num_samples, int num_heads, int qSize, @@ -206,45 +206,47 @@ MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, hiWinIdx, reserveSpace, allocator}; - free(qoSeqArray); - free(kvSeqArray); + allocator.deallocate(qoSeqArray); + allocator.deallocate(kvSeqArray); + + return per_device_state; } void forward_kernel(cudaStream_t stream, - MHAPerDeviceState *m, + MHAPerDeviceState const &device_state, float const *query_ptr, float const *key_ptr, float const *value_ptr, float const *weight_ptr, float *output_ptr) { - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + checkCUDNN(cudnnSetStream(device_state.handle.dnn, stream)); - checkCUDNN(cudnnMultiHeadAttnForward(m->handle.dnn, - m->attnDesc, + checkCUDNN(cudnnMultiHeadAttnForward(device_state.handle.dnn, + device_state.attnDesc, -1, - m->loWinIdx, - m->hiWinIdx, - m->devQoSeqArray, - m->devKvSeqArray, - m->qDesc, + device_state.loWinIdx, + device_state.hiWinIdx, + device_state.devQoSeqArray, + device_state.devKvSeqArray, + device_state.qDesc, query_ptr, nullptr /*residual*/, - m->kDesc, + device_state.kDesc, key_ptr, - m->vDesc, + device_state.vDesc, value_ptr, - m->oDesc, + device_state.oDesc, output_ptr, - m->weightSize, + device_state.weightSize, weight_ptr, - m->handle.workSpaceSize, - m->handle.workSpace, - m->reserveSpaceSize, - m->reserveSpace)); + device_state.handle.workSpaceSize, + device_state.handle.workSpace, + device_state.reserveSpaceSize, + device_state.reserveSpace)); } void backward_kernel(cudaStream_t stream, - MHAPerDeviceState *m, + MHAPerDeviceState const &device_state, float const *query_ptr, float *query_grad_ptr, float const 
*key_ptr, @@ -254,65 +256,61 @@ void backward_kernel(cudaStream_t stream, float const *weight_ptr, float *weight_grad_ptr, float const *output_grad_ptr) { - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + checkCUDNN(cudnnSetStream(device_state.handle.dnn, stream)); - checkCUDNN(cudnnMultiHeadAttnBackwardData(m->handle.dnn, - m->attnDesc, - m->loWinIdx, - m->hiWinIdx, - m->devQoSeqArray, - m->devKvSeqArray, - m->oDesc, + checkCUDNN(cudnnMultiHeadAttnBackwardData(device_state.handle.dnn, + device_state.attnDesc, + device_state.loWinIdx, + device_state.hiWinIdx, + device_state.devQoSeqArray, + device_state.devKvSeqArray, + device_state.oDesc, output_grad_ptr, - m->qDesc, + device_state.qDesc, query_grad_ptr, query_ptr, - m->kDesc, + device_state.kDesc, key_grad_ptr, key_ptr, - m->vDesc, + device_state.vDesc, value_grad_ptr, value_ptr, - m->weightSize, + device_state.weightSize, weight_ptr, - m->handle.workSpaceSize, - m->handle.workSpace, - m->reserveSpaceSize, - m->reserveSpace)); - checkCUDNN(cudnnMultiHeadAttnBackwardWeights(m->handle.dnn, - m->attnDesc, - CUDNN_WGRAD_MODE_ADD, - m->qDesc, - query_ptr, - m->kDesc, - key_ptr, - m->vDesc, - value_ptr, - m->oDesc, - output_grad_ptr, - m->weightSize, - weight_ptr, - weight_grad_ptr, - m->handle.workSpaceSize, - m->handle.workSpace, - m->reserveSpaceSize, - m->reserveSpace)); + device_state.handle.workSpaceSize, + device_state.handle.workSpace, + device_state.reserveSpaceSize, + device_state.reserveSpace)); + checkCUDNN( + cudnnMultiHeadAttnBackwardWeights(device_state.handle.dnn, + device_state.attnDesc, + CUDNN_WGRAD_MODE_ADD, + device_state.qDesc, + query_ptr, + device_state.kDesc, + key_ptr, + device_state.vDesc, + value_ptr, + device_state.oDesc, + output_grad_ptr, + device_state.weightSize, + weight_ptr, + weight_grad_ptr, + device_state.handle.workSpaceSize, + device_state.handle.workSpace, + device_state.reserveSpaceSize, + device_state.reserveSpace)); } -void cleanup_kernel(int *loWinIdx, - int *hiWinIdx, 
- ffAttnDescriptor_t attnDesc, - ffSeqDataDescriptor_t qDesc, - ffSeqDataDescriptor_t kDesc, - ffSeqDataDescriptor_t vDesc, - ffSeqDataDescriptor_t oDesc) { - free(loWinIdx); - free(hiWinIdx); - checkCUDNN(cudnnDestroyAttnDescriptor(attnDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(qDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(kDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(vDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(oDesc)); +void cleanup_kernel(Allocator &allocator, + MHAPerDeviceState const &device_state) { + allocator.deallocate(device_state.loWinIdx); + allocator.deallocate(device_state.hiWinIdx); + checkCUDNN(cudnnDestroyAttnDescriptor(device_state.attnDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(device_state.qDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(device_state.kDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(device_state.vDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(device_state.oDesc)); } } // namespace MultiHeadAttention diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index 08453b9c65..9d35cb6c1a 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -32,7 +32,7 @@ void forward_kernel(cudaStream_t stream, int batch, int a_seq_length_dim, int b_seq_length_dim, - int seq_length = -1) { + int seq_length) { checkCUDA(cublasSetStream(handle.blas, stream)); checkCUDNN(cudnnSetStream(handle.dnn, stream)); int lda = k; diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 0852a953f0..ec3e592607 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -12,15 +12,15 @@ struct MultiHeadAttentionAttrs { req dropout; req bias, add_bias_kv, add_zero_attn; }; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(MultiHeadAttentionAttrs, - embed_dim, - num_heads, - kdim, - vdim, - dropout, - bias, - 
add_bias_kv, - add_zero_attn); +FF_VISITABLE_STRUCT(MultiHeadAttentionAttrs, + embed_dim, + num_heads, + kdim, + vdim, + dropout, + bias, + add_bias_kv, + add_zero_attn); template struct MultiHeadAttentionInputs From edc193578abad92924cb36fbd0e230cdae633558 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Thu, 19 Oct 2023 14:23:52 -0700 Subject: [PATCH 32/33] Replace with unique ptr --- lib/kernels/src/cuda/attention_kernels.cu | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/lib/kernels/src/cuda/attention_kernels.cu b/lib/kernels/src/cuda/attention_kernels.cu index 5e11688a46..684feb3c9d 100644 --- a/lib/kernels/src/cuda/attention_kernels.cu +++ b/lib/kernels/src/cuda/attention_kernels.cu @@ -102,8 +102,8 @@ MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, axes[2] = CUDNN_SEQDATA_BEAM_DIM; axes[1] = CUDNN_SEQDATA_TIME_DIM; axes[0] = CUDNN_SEQDATA_BATCH_DIM; - int *qoSeqArray = (int *)malloc(sizeof(int) * num_samples); - int *kvSeqArray = (int *)malloc(sizeof(int) * num_samples); + std::unique_ptr qoSeqArray (new int[num_samples]); + std::unique_ptr kvSeqArray (new int[num_samples]); for (int i = 0; i < num_samples; i++) { qoSeqArray[i] = qoSeqLength; kvSeqArray[i] = kvSeqLength; @@ -120,7 +120,7 @@ MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, dimA, axes, num_samples, - qoSeqArray, + qoSeqArray.get(), NULL)); } // Set kDesc @@ -135,7 +135,7 @@ MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, dimA, axes, num_samples, - kvSeqArray, + kvSeqArray.get(), NULL)); } // Set vDesc @@ -150,7 +150,7 @@ MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, dimA, axes, num_samples, - kvSeqArray, + kvSeqArray.get(), NULL)); } // Set oDesc @@ -165,7 +165,7 @@ MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, dimA, axes, num_samples, - qoSeqArray, + qoSeqArray.get(), NULL)); } // allocate memory for the seqArray and reserve space @@ -174,12 +174,12 @@ MHAPerDeviceState 
init_kernel(PerDeviceFFHandle const &handle, devQoSeqArray = (int *)allocator.allocate(totalSize); checkCUDA(cudaMemcpy(devQoSeqArray, - qoSeqArray, + qoSeqArray.get(), sizeof(int) * num_samples, cudaMemcpyHostToDevice)); devKvSeqArray = devQoSeqArray + num_samples; checkCUDA(cudaMemcpy(devKvSeqArray, - kvSeqArray, + kvSeqArray.get(), sizeof(int) * num_samples, cudaMemcpyHostToDevice)); reserveSpace = devKvSeqArray + num_samples; @@ -206,8 +206,6 @@ MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, hiWinIdx, reserveSpace, allocator}; - allocator.deallocate(qoSeqArray); - allocator.deallocate(kvSeqArray); return per_device_state; } From 07d9db9cb6be6c01662090104c23b105dc595863 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Mon, 23 Oct 2023 11:30:21 -0700 Subject: [PATCH 33/33] Format --- lib/kernels/src/cuda/attention_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/kernels/src/cuda/attention_kernels.cu b/lib/kernels/src/cuda/attention_kernels.cu index 684feb3c9d..c2225c13d4 100644 --- a/lib/kernels/src/cuda/attention_kernels.cu +++ b/lib/kernels/src/cuda/attention_kernels.cu @@ -102,8 +102,8 @@ MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, axes[2] = CUDNN_SEQDATA_BEAM_DIM; axes[1] = CUDNN_SEQDATA_TIME_DIM; axes[0] = CUDNN_SEQDATA_BATCH_DIM; - std::unique_ptr qoSeqArray (new int[num_samples]); - std::unique_ptr kvSeqArray (new int[num_samples]); + std::unique_ptr qoSeqArray(new int[num_samples]); + std::unique_ptr kvSeqArray(new int[num_samples]); for (int i = 0; i < num_samples; i++) { qoSeqArray[i] = qoSeqLength; kvSeqArray[i] = kvSeqLength;