From 6b248cf20e560dbc2ff57d0db2bf463775177620 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 22 Aug 2023 14:11:42 -0700 Subject: [PATCH 01/19] batch matmul initial commit --- .../include/kernels/batch_matmul_kernels.h | 23 +- lib/kernels/src/hip/batch_matmul_kernels.cpp | 7 +- .../include/op-attrs/ops/batch_matmul.h | 4 +- lib/op-attrs/src/batch_matmul.cc | 8 + lib/runtime/src/ops/batch_matmul.cc | 855 ++++-------------- lib/runtime/src/ops/batch_matmul.h | 441 ++++++++- 6 files changed, 655 insertions(+), 683 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 0e4437bdb8..cdffdf1907 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -6,17 +6,26 @@ namespace FlexFlow { -class BatchMatmulPerDeviceState : public PerDeviceOpState { -public: - BatchMatmulPerDeviceState(FFHandler handler); - int a_seq_length_dim, b_seq_length_dim; +struct BMMPerDeviceState { + PerDeviceFFHandle handle; + Allocator allocator; + int a_seq_length_dim; + int b_seq_length_dim; }; +FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, + handle,); + namespace Kernels { namespace BatchMatmul { +BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, + Allocator allocator, + int a_seq_length_dim, + int b_seq_length_dim); + void forward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, + BMMPerDeviceState const *meta, float *o_ptr, float const *a_ptr, float const *b_ptr, @@ -25,12 +34,10 @@ void forward_kernel(ffStream_t stream, int n, int k, int batch, - int a_seq_length_dim = -1, - int b_seq_length_dim = -1, int seq_length = -1); void backward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, + BMMPerDeviceState const *meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp b/lib/kernels/src/hip/batch_matmul_kernels.cpp index d8b6500326..e5334b1841 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -19,9 +19,6 @@ namespace FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -42,9 +39,9 @@ void forward_kernel(hipStream_t stream, int k, int batch, hipStream_t stream, - int a_seq_length_dim, - int b_seq_length_dim, int seq_length) { + int a_seq_length_dim = meta->a_seq_length_dim; + int b_seq_length_dim = meta->b_seq_length_dim; checkCUDA(hipblasSetStream(meta->handle.blas, stream)); checkCUDNN(miopenSetStream(meta->handle.dnn, stream)); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index c74824570c..00c700ba20 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -12,8 +12,10 @@ struct BatchMatmulAttrs { }; FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); -CHECK_VALID_OP_ATTR(BatchMatmulAttrs); +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); +CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/batch_matmul.cc b/lib/op-attrs/src/batch_matmul.cc index 1cc8c5cfda..bd61c24737 100644 --- a/lib/op-attrs/src/batch_matmul.cc +++ b/lib/op-attrs/src/batch_matmul.cc @@ -2,6 +2,14 @@ namespace FlexFlow { +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs) { + return attrs.a_seq_length_dim; +} + +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs) { + return attrs.b_seq_length_dim; +} + /* bool BatchMatmulAttrs::is_valid( */ /* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { */ diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 3e860bd413..9bbc050f8b 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -13,754 +13,287 @@ * limitations under the License. */ +// #include "batch_matmul.h" +// #include "kernels/batch_matmul_kernels.h" +// #include "kernels/profiling.h" +// #include "legion/legion_utilities.h" +// #include "tasks.h" + #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" -#include "kernels/profiling.h" -#include "legion/legion_utilities.h" -#include "tasks.h" +#include "legion.h" +#include "op-attrs/ops/batch_matmul.h" namespace FlexFlow { using namespace FlexFlow::Kernels::BatchMatmul; +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; + enum Slots { - A_INPUT, - B_INPUT, - OUTPUT, - A_INPUT_GRAD, - B_INPUT_GRAD, - OUTPUT_GRAD, - ATTRS, - PROFILING -}; + A_INPUT, //tensor + B_INPUT, //tensor + OUTPUT, //tensor + PROFILING, + HANDLE, + A_SEQ_LENGTH_DIM, + B_SEQ_LENGTH_DIM, + PER_DEVICE_STATE + }; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; + OpTaskBinding init; - b.bind_arg(ATTRS, attrs); - b.bind_arg(PROFILING, enable_profiling()); + init.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init.bind_arg(HANDLE, ff_handle()); - return {BATCHMATMUL_INIT_TASK_ID, b}; + return {BATCHMATMUL_INIT_TASK_ID, init}; } OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; - - b.bind(A_INPUT, input_tensor(0)); - b.bind(B_INPUT, input_tensor(1)); - b.bind(OUTPUT, output_tensor(0)); - - return {BATCHMATMUL_FWD_TASK_ID, b}; -} - -OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return {BATCHMATMUL_BWD_TASK_ID, b}; -} - -BatchMatmulParams BatchMatmul::get_params() const { - BatchMatmulParams params; - params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; - params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; - return params; -} - -Tensor FFModel::batch_matmul(const Tensor A, - const Tensor B, - int a_seq_length_dim, - int b_seq_length_dim, - char const *name) { - Layer *bmm = new Layer(this, - OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B); - assert((a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - int dims[MAX_TENSOR_DIM]; - int numdim = A->num_dims; - for (int i = 0; i < numdim; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - bmm->outputs[0] = create_tensor_legion_ordering( - numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); - bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); - bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); - layers.push_back(bmm); - return bmm->outputs[0]; -} - -Op *BatchMatmul::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("a_seq_length_dim", value); - int a_seq_length_dim = value; - layer->get_int_property("b_seq_length_dim", value); - int b_seq_length_dim = value; - return new BatchMatmul(model, - inputs[0], - inputs[1], - a_seq_length_dim, - b_seq_length_dim, - layer->name); -} - -BatchMatmul::BatchMatmul( - FFModel &model, - BatchMatmulParams const ¶ms, - std::pair const &inputs, - char const *name) - : BatchMatmul(model, - inputs.first, - inputs.second, - params.a_seq_length_dim, - params.b_seq_length_dim, - name) {} - -// return A*B -BatchMatmul::BatchMatmul(FFModel &model, - const ParallelTensor A, - const ParallelTensor B, - int _a_seq_length_dim, - int _b_seq_length_dim, - char const *name) - : Op(model, - OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B), - a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), - b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { - assert((_a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((_b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < A->num_dims; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - numOutputs = 1; - outputs[0] = model.create_parallel_tensor_legion_ordering( - A->num_dims, dims, DT_FLOAT, this); - // C is not none - // if (C != Tensor::NO_TENSOR) { - // numInputs = 3; - // assert(C.num_dims == outputs[0].num_dims); - // for (int i = 0; i < C.num_dims; i++) - // assert(C.adim[i] == outputs[0].adim[i]); - //} -} - -void BatchMatmul::serialize(Legion::Serializer &sez) const { - BatchMatmulParams params = get_params(); - sez.serialize(params.a_seq_length_dim); - sez.serialize(params.b_seq_length_dim); -} - -using PCG::Node; -/*static*/ -Node BatchMatmul::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 2); - int a_seq_length_dim, b_seq_length_dim; - dez.deserialize(a_seq_length_dim); - dez.deserialize(b_seq_length_dim); - - BatchMatmulParams params; - params.a_seq_length_dim = a_seq_length_dim; - params.b_seq_length_dim = b_seq_length_dim; - return ff.get_or_create_node({inputs[0], inputs[1]}, params); -} + OpTaskBinding fwd; -Op *BatchMatmul::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - BatchMatmulParams params = get_params(); - return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); -} - -template <> -void register_task() { - OpTaskSignature sig(OpTaskType::INIT); + fwd.bind(A_INPUT, input_tensor(0)); + fwd.bind(B_INPUT, input_tensor(1)); + fwd.bind(OUTPUT, output_tensor(0)); - sig.add_arg_slot(ATTRS); - sig.add_arg_slot(PROFILING); + fwd.bind_arg(PROFILING, profiling_settings()); + fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", sig, init_task); + return {BATCHMATMUL_FWD_TASK_ID, fwd}; } -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(A_INPUT, READ_WRITE); - fwd.add_input_slot(B_INPUT, READ_WRITE); - fwd.add_output_slot(OUTPUT); - - return fwd; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_input_slot(A_INPUT); - bwd.add_input_slot(B_INPUT); - bwd.add_input_grad_slot(A_INPUT_GRAD); - bwd.add_input_grad_slot(B_INPUT_GRAD); - bwd.add_output_slot(OUTPUT); - bwd.add_output_grad_slot(OUTPUT_GRAD); - - return bwd; -} - -OpTaskBinding BatchMatmul::get_init_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); +OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { + OpTaskBinding bwd = infer_bwd_binding(forward(attrs).binding); - return binding; + return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -OpTaskBinding BatchMatmul::get_fwd_task_binding() const { - OpTaskBinding binding; - - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind(OUTPUT, output_tensor(0)); - - binding.bind_arg(ATTRS, this->attrs); - return binding; +static DeviceSpecificArg init_task_impl(TaskArgumentAccessor const &acc) { + auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); + auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + Allocator allocator = acc.get_allocator(); + + DeviceSpecificArg per_device_state = + acc.create_device_specific( + init_kernel(handle, + allocator, + a_seq_length_dim, + b_seq_length_dim)); + + // assert(weight.shape.get_volume() * sizeof(float) == + // acc.unwrap(per_device_state)->weightSize); + return per_device_state; } -OpTaskBinding BatchMatmul::get_bwd_task_binding() const { - OpTaskBinding binding; - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); - binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); - - binding.bind(OUTPUT, output_tensor(0)); - binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); - - binding.bind_arg(ATTRS, this->attrs); - return binding; -} - -void BatchMatmul::init(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // init_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} -} // namespace FlexFlow -// / -// template -// void BatchMatmul::init_with_dim(FFModel const &ff) { -// assert(check_output_input_weight_same_parallel_is()); -// parallel_is = outputs[0]->parallel_is; -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_init(ff, argmap); -// IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, -// parallel_is, -// TaskArgument(this, sizeof(BatchMatmul)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// FutureMap fm = runtime->execute_index_space(ctx, launcher); -// fm.wait_all_results(); -// set_opmeta_from_futuremap(ff, fm); -// } - -PerDeviceOpState * - BatchMatmul::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static DeviceSpecificArg + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - - FFHandler handle = *((FFHandler const *)task->local_args); - BatchMatmulPerDeviceState *m = new BatchMatmulPerDeviceState(handle); - m->profiling = profiling; - m->a_seq_length_dim = attrs.a_seq_length_dim; - m->b_seq_length_dim = attrs.b_seq_length_dim; - return m; -} - -void BatchMatmul::forward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // forward_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} + return init_task_impl(acc); } -// template -// void BatchMatmul::forward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_forward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_FWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](O): output - regions[1](I): A - regions[2](I): B - ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? - output = A * B /////////+ C -*/ -void BatchMatmul::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(regions.size() == 3); assert(task->regions.size() == 3); - TaskArgumentAccessor acc(task, regions, ctx, runtime); + auto a_input = acc.get_tensor(A_INPUT); + auto b_input = acc.get_tensor(B_INPUT); + auto output = acc.get_tensor(OUTPUT); - // const BatchMatmul* bmm = (const BatchMatmul*) task->args; + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - // BatchMatmulMeta const *meta = *((BatchMatmulMeta **)task->local_args); - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - auto a_input = acc.get_tensor(A_INPUT); - auto b_input = acc.get_tensor(B_INPUT); - auto output = acc.get_tensor(OUTPUT); - - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); + int batch = 1; for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); batch *= dim_size; } - float *out_ptr = output.get_float_ptr(); - c float const *a_ptr = a_input.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float const *c_ptr = NULL; - // if (regions.size() == 4) { - // Domain c_domain = runtime->get_index_space_domain( - // ctx, task->regions[3].region.get_index_space()); - // assert(c_domain == a_domain); - // c_ptr = helperGetTensorPointerRO( - // regions[3], task->regions[3], FID_DATA, ctx, runtime); - // } - - profile(forward_kernel, - meta->profiling, + + return profile(forward_kernel, + profiling, "[BatchMatmul] forward_time = %.2lfms\n", - out_ptr, - a_ptr, - b_ptr, - c_ptr, + per_device_state, + output.get_float_ptr(), + a_input.get_float_ptr(), + b_input.get_float_ptr(), + NULL, //c_ptr m, n, k, batch, - meta->a_seq_length_dim, - meta->b_seq_length_dim, iter_config->seq_length); } -void BatchMatmul::backward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - backward_with_dim(ff); \ - break; \ - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); - } +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -// template -// void BatchMatmul::backward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_backward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_BWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// // regions[0](I): output -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// // regions[1](I): output_grad -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region_grad)); -// launcher.add_field(1, FID_DATA); -// // regions[2](I): A -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[0]->region)); -// launcher.add_field(2, FID_DATA); -// // regions[3](I/O): A_grad -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[0]->region_grad)); -// launcher.add_field(3, FID_DATA); -// // regions[4](I): B -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[1]->region)); -// launcher.add_field(4, FID_DATA); -// // regions[5](I/O): B_grad -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[1]->region_grad)); -// launcher.add_field(5, FID_DATA); -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -__host__ void - BatchMatmul::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static optional backward_task_impl(TaskArgumentAccessor const &acc) { // Currently assume C is NULL assert(regions.size() == 6); assert(task->regions.size() == 6); + // BatchMatmul* bmm = (BatchMatmul*) task->args; - TaskArgumentAccessor acc(task, regions, ctx, runtime); FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - // output domains - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor(OUTPUT_GRAD); - assert(output == - output_grad); // is this equivalent to checking `Domain` equality? - // A domains - auto a_input = acc.get_tensor(A_INPUT); - auto a_input_grad = acc.get_tensor(A_INPUT_GRAD); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + // is this equivalent to checking `Domain` equality? + assert(output == output_grad); + + auto a_input = acc.get_tensor(A_INPUT); + auto a_input_grad = acc.get_tensor_grad(A_INPUT); assert(a_input == a_input_grad); - // B domains - auto b_input = acc.get_tensor(B_INPUT); - auto b_input_grad = acc.get_tensor(B_INPUT_GRAD); + + auto b_input = acc.get_tensor(B_INPUT); + auto b_input_grad = acc.get_tensor_grad(B_INPUT); assert(b_input == b_input_grad); // check dins - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); batch *= dim_size; } - // get pointers - float const *out_ptr = output.get_float_ptr(); - float const *out_grad_ptr = output_grad.get_float_ptr(); - float const *a_ptr = a_input.get_float_ptr(); - float *a_grad_ptr = a_input_grad.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float *b_grad_ptr = b_input_grad.get_float_ptr(); - - float *c_grad_ptr = NULL; // TODO: add support for meta->a_seq_length_dim >= 0 // or meta->b_seq_length_dim >= 0 assert((meta->a_seq_length_dim >= a_len) || (iter_config->seq_length == 0)); assert((meta->b_seq_length_dim >= b_len) || (iter_config->seq_length == 0)); - profile(backward_kernel, - meta->profiling, + return profile(backward_kernel, + profiling, "[BatchMatmul] backward_time = %.2lfms\n", - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + a_input.get_float_ptr(), + a_input_grad.get_float_ptr(), + b_input.get_float_ptr(), + b_input_grad.get_float_ptr(), + NULL, //c_grad_ptr m, n, k, batch); } -void BatchMatmul::print_layer(FFModel const &ff) { - return; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -bool BatchMatmul::measure_operator_cost(Simulator *sim, - MachineView const &pc, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input0, sub_input1; - if (!outputs[0]->get_sub_tensor(pc, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(pc, sub_input0)) { - return false; - } - if (!inputs[1]->get_sub_tensor(pc, sub_input1)) { - return false; - } - - int input0_c = sub_input0.dims[0].size; - int input0_r = sub_input0.dims[1].size; - int input1_c = sub_input1.dims[0].size; - int input1_r = sub_input1.dims[1].size; - int output_c = sub_output.dims[0].size; - int output_r = sub_output.dims[1].size; - - assert(input0_c == input1_r); - assert(input0_r == output_r); - assert(input1_c == output_c); - - assert(sub_input0.dims[2] == sub_input1.dims[2]); - assert(sub_input1.dims[2] == sub_output.dims[2]); - int batch = 1; - assert(sub_input0.num_dims == sub_input1.num_dims); - for (int i = 2; i < sub_input0.num_dims; i++) { - assert(sub_input0.dims[i] == sub_input1.dims[i]); - assert(sub_input0.dims[i] == sub_output.dims[i]); - batch *= sub_input0.dims[i].size; - } - - BatchMatmulPerDeviceState *meta = sim->batch_matmul_meta; +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc) const { + auto env = sim.new_environment(); + + //todo add get_output_shape and get_weights_shape to batch_matmul op-attrs + // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); + // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); + + SimTaskBinding init_binding; + init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init_binding.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init_binding.bind_arg(HANDLE, ff_handle()); + + auto init_accessor = + env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, init_binding); + DeviceSpecificArg per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(A_INPUT, a_input); + fwd_binding.bind(B_INPUT, b_input); + // fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); +} - // allocate tensors in simulator - sim->free_all(); - float *a_ptr = (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - assert(a_ptr != NULL); - float *b_ptr = (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - assert(b_ptr != NULL); - float *c_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); - float *out_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + init.add_arg_slot(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init.add_arg_slot(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init.add_unchecked_arg_slot(HANDLE, ff_handle()); - int m = input1_c; - int n = input0_r; - int k = input0_c; + register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); +} - assert(meta->profiling == false); +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, meta, out_ptr, a_ptr, b_ptr, c_ptr, m, n, k, batch); - }; + fwd.add_input_slot(A_INPUT); + fwd.add_input_slot(B_INPUT); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - if (sim->computationMode == COMP_MODE_TRAINING) { - float *a_grad_ptr = - (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - float *b_grad_ptr = - (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - float *c_grad_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *out_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, - m, - n, - k, - batch); - }; - } + register_task(BATCHMATMUL_FWD_TASK_ID, "BatchMatmul Fwd", fwd, forward_task); +} - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf) backward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time); - } +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(ATTENTION_FWD_TASK_ID)); - return true; + register_task(BATCHMATMUL_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); } -} -; // namespace FlexFlow + +}; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index c133c2a875..ab9bb45e8d 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -1,10 +1,14 @@ #ifndef _FLEXFLOW_BATCH_MATMUL_H #define _FLEXFLOW_BATCH_MATMUL_H +// #include "op-attrs/ops/batch_matmul.h" +// #include "task_spec/op_task_invocation.h" +// #include "task_spec/op_task_signature.h" +// #include "sim_environment.h" + #include "op-attrs/ops/batch_matmul.h" -#include "op_task_invocation.h" -#include "op_task_signature.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -19,12 +23,12 @@ OpTaskInvocation init(BatchMatmulAttrs const &); OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - BatchMatmulAttrs const &attrs, - ParallelTensorShape const &lhs_input_shape, - ParallelTensorShape const &rhs_input_shape, - ProfilingSettings const &settings, - MachineView const &); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc); /* class BatchMatmul : public Op { */ /* public: */ @@ -84,3 +88,424 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + + + +// BatchMatmulParams BatchMatmul::get_params() const { +// BatchMatmulParams params; +// params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; +// params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; +// return params; +// } + +// Tensor FFModel::batch_matmul(const Tensor A, +// const Tensor B, +// int a_seq_length_dim, +// int b_seq_length_dim, +// char const *name) { +// Layer *bmm = new Layer(this, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B); +// assert((a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// int dims[MAX_TENSOR_DIM]; +// int numdim = A->num_dims; +// for (int i = 0; i < numdim; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// bmm->outputs[0] = create_tensor_legion_ordering( +// numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); +// bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); +// bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); +// layers.push_back(bmm); +// return bmm->outputs[0]; +// } + +// Op *BatchMatmul::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("a_seq_length_dim", value); +// int a_seq_length_dim = value; +// layer->get_int_property("b_seq_length_dim", value); +// int b_seq_length_dim = value; +// return new BatchMatmul(model, +// inputs[0], +// inputs[1], +// a_seq_length_dim, +// b_seq_length_dim, +// layer->name); +// } + +// BatchMatmul::BatchMatmul( +// FFModel &model, +// BatchMatmulParams const ¶ms, +// std::pair const &inputs, +// char const *name) +// : BatchMatmul(model, +// inputs.first, +// inputs.second, +// params.a_seq_length_dim, +// params.b_seq_length_dim, +// name) {} + +// // return A*B +// BatchMatmul::BatchMatmul(FFModel &model, +// const ParallelTensor A, +// const ParallelTensor B, +// int _a_seq_length_dim, +// int _b_seq_length_dim, +// char const *name) +// : Op(model, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B), +// a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), +// b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { +// assert((_a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((_b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < A->num_dims; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// numOutputs = 1; +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// A->num_dims, dims, DT_FLOAT, this); +// // C is not none +// // if (C != Tensor::NO_TENSOR) { +// // numInputs = 3; +// // assert(C.num_dims == outputs[0].num_dims); +// // for (int i = 0; i < C.num_dims; i++) +// // assert(C.adim[i] == outputs[0].adim[i]); +// //} +// } + +// void BatchMatmul::serialize(Legion::Serializer &sez) const { +// BatchMatmulParams params = get_params(); +// sez.serialize(params.a_seq_length_dim); +// sez.serialize(params.b_seq_length_dim); +// } + +// using PCG::Node; +// /*static*/ +// Node BatchMatmul::deserialize(FFModel &ff, +// Legion::Deserializer &dez, +// ParallelTensor inputs[], +// int num_inputs) { +// assert(num_inputs == 2); +// int a_seq_length_dim, b_seq_length_dim; +// dez.deserialize(a_seq_length_dim); +// dez.deserialize(b_seq_length_dim); + +// BatchMatmulParams params; +// params.a_seq_length_dim = a_seq_length_dim; +// params.b_seq_length_dim = b_seq_length_dim; +// return ff.get_or_create_node({inputs[0], inputs[1]}, params); +// } + +// Op *BatchMatmul::materialize(FFModel &ff, +// ParallelTensor inputs[], +// int num_inputs) const { +// BatchMatmulParams params = get_params(); +// return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); +// } + + +// void BatchMatmul::forward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // forward_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); +// break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + +// template +// void BatchMatmul::forward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_FWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// for (int i = 0; i < numInputs; i++) { +// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[i]->region)); +// launcher.add_field(i + 1, FID_DATA); +// } +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](O): output + regions[1](I): A + regions[2](I): B + ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? + output = A * B /////////+ C +*/ + + +// void BatchMatmul::init(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // init_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); +// break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } // namespace FlexFlow +// // / +// // template +// // void BatchMatmul::init_with_dim(FFModel const &ff) { +// // assert(check_output_input_weight_same_parallel_is()); +// // parallel_is = outputs[0]->parallel_is; +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_init(ff, argmap); +// // IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(BatchMatmul)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // FutureMap fm = runtime->execute_index_space(ctx, launcher); +// // fm.wait_all_results(); +// // set_opmeta_from_futuremap(ff, fm); +// // } + +// OpTaskBinding BatchMatmul::get_bwd_task_binding() const { +// OpTaskBinding binding; +// binding.bind(A_INPUT, input_tensor(0)); +// binding.bind(B_INPUT, input_tensor(1)); +// binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); +// binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); + +// binding.bind(OUTPUT, output_tensor(0)); +// binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + + +// static OpTaskSignature get_fwd_task_signature() { +// OpTaskSignature fwd(OpTaskType::FWD); + +// fwd.add_input_slot(A_INPUT, READ_WRITE); +// fwd.add_input_slot(B_INPUT, READ_WRITE); +// fwd.add_output_slot(OUTPUT); + +// return fwd; +// } + +// static OpTaskSignature get_bwd_task_signature() { +// OpTaskSignature bwd(OpTaskType::BWD); + +// bwd.add_input_slot(A_INPUT); +// bwd.add_input_slot(B_INPUT); +// bwd.add_input_grad_slot(A_INPUT_GRAD); +// bwd.add_input_grad_slot(B_INPUT_GRAD); +// bwd.add_output_slot(OUTPUT); +// bwd.add_output_grad_slot(OUTPUT_GRAD); + +// return bwd; +// } + +// OpTaskBinding BatchMatmul::get_init_task_binding() const { +// OpTaskBinding binding; + +// binding.bind_arg(ATTRS, this->attrs); +// binding.bind_arg(PROFILING, this->profiling); + +// return binding; +// } + +// OpTaskBinding BatchMatmul::get_fwd_task_binding() const { +// OpTaskBinding binding; + +// binding.bind(A_INPUT, input_tensor(0)); +// binding.bind(B_INPUT, input_tensor(1)); +// binding.bind(OUTPUT, output_tensor(0)); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + +//void BatchMatmul::backward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// backward_with_dim(ff); \ +// break; \ +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + + +// void BatchMatmul::print_layer(FFModel const &ff) { +// return; +// } + + + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + regions[5](I/O): B_grad + regions[6](I/O): C_grad +*/ +// template +// void BatchMatmul::backward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_BWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// // regions[0](I): output +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// // regions[1](I): output_grad +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// // regions[2](I): A +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(2, FID_DATA); +// // regions[3](I/O): A_grad +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // regions[4](I): B +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[1]->region)); +// launcher.add_field(4, FID_DATA); +// // regions[5](I/O): B_grad +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[1]->region_grad)); +// launcher.add_field(5, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + regions[5](I/O): B_grad + regions[6](I/O): C_grad +*/ \ No newline at end of file From d9aae744914fb6d710a5f174d4e88cc0cd7a95d5 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 22 Aug 2023 14:29:35 -0700 Subject: [PATCH 02/19] fix FF_VISITABLE_STRUCT_NO_EQ --- lib/kernels/include/kernels/batch_matmul_kernels.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index cdffdf1907..b8e6b9cd4b 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -14,7 +14,10 @@ struct BMMPerDeviceState { }; FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, - handle,); + handle, + allocator, + a_seq_length_dim, + b_seq_length_dim); namespace Kernels { namespace BatchMatmul { From 2073822dbc697615407763f10a8383e864c6929e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 12:03:39 -0700 Subject: [PATCH 03/19] finish draft 1 batch_matmul --- .../include/kernels/batch_matmul_kernels.h | 7 ++-- lib/runtime/include/runtime/config.h | 8 +++-- lib/runtime/src/ops/batch_matmul.cc | 32 ++++++++++--------- lib/runtime/src/task_spec/op_arg_ref.h | 8 ++++- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index b8e6b9cd4b..1506305e9e 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -10,7 +10,7 @@ struct BMMPerDeviceState { PerDeviceFFHandle handle; Allocator allocator; int a_seq_length_dim; - int b_seq_length_dim; + req b_seq_length_dim; }; FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, @@ -28,7 +28,7 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, int b_seq_length_dim); void forward_kernel(ffStream_t stream, - BMMPerDeviceState const *meta, + BMMPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, @@ -40,14 +40,13 @@ void forward_kernel(ffStream_t stream, int seq_length = -1); void backward_kernel(ffStream_t stream, - BMMPerDeviceState const *meta, + BMMPerDeviceState const &meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, float *a_grad_ptr, float const *b_ptr, float *b_grad_ptr, - float *c_grad_ptr, int m, int n, int k, diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index a7b8d86171..fc45c2c76c 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -104,13 +104,15 @@ struct FFConfig : public use_visitable_cmp { int python_data_loader_type = 2; }; -class FFIterationConfig { -public: +struct FFIterationConfig { FFIterationConfig(); void reset(); - int seq_length; + req seq_length; }; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, + seq_length); + enum FieldIDs { FID_DATA, }; diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 9bbc050f8b..761c565657 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -41,7 +41,8 @@ enum Slots { HANDLE, A_SEQ_LENGTH_DIM, B_SEQ_LENGTH_DIM, - PER_DEVICE_STATE + PER_DEVICE_STATE, + ITERATION_CONFIG }; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { @@ -63,6 +64,7 @@ OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { fwd.bind_arg(PROFILING, profiling_settings()); fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + fwd.bind_arg(ITERATION_CONFIG, iteration_config()); return {BATCHMATMUL_FWD_TASK_ID, fwd}; } @@ -110,7 +112,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; + FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); int m = b_input.shape[legion_dim_t(0)]; assert(m == output.shape[legion_dim_t(0)]); @@ -123,7 +125,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { + for (int i = 2; i < a_input.shape.get_dim(); i++) { //get_dim() or get_volume()? int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -137,12 +139,12 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { output.get_float_ptr(), a_input.get_float_ptr(), b_input.get_float_ptr(), - NULL, //c_ptr + nullptr, //c_ptr m, n, k, batch, - iter_config->seq_length); + iter_config.seq_length); } static void forward_task(Task const *task, @@ -159,7 +161,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(task->regions.size() == 6); // BatchMatmul* bmm = (BatchMatmul*) task->args; - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; + FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); @@ -186,7 +188,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { + for (int i = 2; i < a_input.shape.get_dim(); i++) { //@colin get_dim() or get_volume()? int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -195,8 +197,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { // TODO: add support for meta->a_seq_length_dim >= 0 // or meta->b_seq_length_dim >= 0 - assert((meta->a_seq_length_dim >= a_len) || (iter_config->seq_length == 0)); - assert((meta->b_seq_length_dim >= b_len) || (iter_config->seq_length == 0)); + assert((meta->a_seq_length_dim >= a_len) || (iter_config.seq_length == 0)); + assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); return profile(backward_kernel, profiling, @@ -208,7 +210,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { a_input_grad.get_float_ptr(), b_input.get_float_ptr(), b_input_grad.get_float_ptr(), - NULL, //c_grad_ptr m, n, k, @@ -228,10 +229,10 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, InputParallelTensorDesc const &a_input, InputParallelTensorDesc const &b_input, ProfilingSettings const &settings, - MachineView const &pc) const { + MachineView const &pc) { auto env = sim.new_environment(); - //todo add get_output_shape and get_weights_shape to batch_matmul op-attrs + // @colin todo add get_output_shape and get_weights_shape to batch_matmul op-attrs // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); @@ -248,6 +249,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding fwd_binding; fwd_binding.bind(A_INPUT, a_input); fwd_binding.bind(B_INPUT, b_input); + //@colin will uncomment after get_output_shape is done // fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); @@ -268,9 +270,9 @@ template <> void register_task() { OpTaskSignature init(OpTaskType::INIT); - init.add_arg_slot(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); - init.add_arg_slot(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); - init.add_unchecked_arg_slot(HANDLE, ff_handle()); + init.add_arg_slot(A_SEQ_LENGTH_DIM); + init.add_arg_slot(B_SEQ_LENGTH_DIM); + init.add_unchecked_arg_slot(HANDLE); register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); } diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index 6e921c05e8..bc4f38e5fb 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -4,10 +4,12 @@ #include "arg_ref.h" #include "device_specific_arg.h" #include "op-attrs/parallel_tensor_shape.h" +#include "runtime/config.h" + namespace FlexFlow { -enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE }; +enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE, ITERATION_CONFIG }; template using OpArgRef = ArgRef; @@ -23,6 +25,10 @@ OpArgRef input_parallel_tensor_shape(int idx) { return {OpArgRefType::PARALLEL_TENSOR_SHAPE}; } +OpArgRef iteration_config() { + return {OpArgRefType::ITERATION_CONFIG}; +} + } // namespace FlexFlow #endif From d5806a01163ab9fb13f0457c7cc4ba4e330b41d0 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 13:35:55 -0700 Subject: [PATCH 04/19] add output and weights --- lib/op-attrs/include/op-attrs/ops/batch_matmul.h | 3 +++ lib/runtime/src/ops/batch_matmul.cc | 15 ++++----------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 00c700ba20..fa80ed3da9 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -14,6 +14,9 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); +ParallelTensorShape get_weights_shape(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 761c565657..3fabc86d1a 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -13,16 +13,11 @@ * limitations under the License. */ -// #include "batch_matmul.h" -// #include "kernels/batch_matmul_kernels.h" -// #include "kernels/profiling.h" -// #include "legion/legion_utilities.h" -// #include "tasks.h" - #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" #include "legion.h" #include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/get_output_shapes.h" namespace FlexFlow { @@ -232,9 +227,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, MachineView const &pc) { auto env = sim.new_environment(); - // @colin todo add get_output_shape and get_weights_shape to batch_matmul op-attrs - // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); - // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); + ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape weight_shape = get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); @@ -249,8 +243,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding fwd_binding; fwd_binding.bind(A_INPUT, a_input); fwd_binding.bind(B_INPUT, b_input); - //@colin will uncomment after get_output_shape is done - // fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); From f620070ea8bc8e7373915a5063e1a5977ed7f433 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 13:45:00 -0700 Subject: [PATCH 05/19] format --- .../include/kernels/batch_matmul_kernels.h | 15 +-- .../include/op-attrs/ops/batch_matmul.h | 4 +- lib/runtime/include/runtime/config.h | 3 +- lib/runtime/src/ops/batch_matmul.cc | 106 +++++++++--------- lib/runtime/src/ops/batch_matmul.h | 62 +++++----- lib/runtime/src/task_spec/op_arg_ref.h | 7 +- 6 files changed, 97 insertions(+), 100 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 1506305e9e..23522a89f7 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -13,19 +13,16 @@ struct BMMPerDeviceState { req b_seq_length_dim; }; -FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, - handle, - allocator, - a_seq_length_dim, - b_seq_length_dim); +FF_VISITABLE_STRUCT_NO_EQ( + BMMPerDeviceState, handle, allocator, a_seq_length_dim, b_seq_length_dim); namespace Kernels { namespace BatchMatmul { BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, - Allocator allocator, - int a_seq_length_dim, - int b_seq_length_dim); + Allocator allocator, + int a_seq_length_dim, + int b_seq_length_dim); void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, @@ -40,7 +37,7 @@ void forward_kernel(ffStream_t stream, int seq_length = -1); void backward_kernel(ffStream_t stream, - BMMPerDeviceState const &meta, + BMMPerDeviceState const &meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index fa80ed3da9..6ac2a2cf69 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -15,8 +15,8 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); ParallelTensorShape get_weights_shape(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); + ParallelTensorShape const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index fc45c2c76c..21cfca20cf 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -110,8 +110,7 @@ struct FFIterationConfig { req seq_length; }; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, - seq_length); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); enum FieldIDs { FID_DATA, diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 3fabc86d1a..0f4ad9e865 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -16,8 +16,8 @@ #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" #include "legion.h" -#include "op-attrs/ops/batch_matmul.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/batch_matmul.h" namespace FlexFlow { @@ -29,16 +29,16 @@ using Legion::Runtime; using Legion::Task; enum Slots { - A_INPUT, //tensor - B_INPUT, //tensor - OUTPUT, //tensor + A_INPUT, // tensor + B_INPUT, // tensor + OUTPUT, // tensor PROFILING, HANDLE, A_SEQ_LENGTH_DIM, B_SEQ_LENGTH_DIM, PER_DEVICE_STATE, ITERATION_CONFIG - }; +}; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { OpTaskBinding init; @@ -70,18 +70,16 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static DeviceSpecificArg init_task_impl(TaskArgumentAccessor const &acc) { +static DeviceSpecificArg + init_task_impl(TaskArgumentAccessor const &acc) { auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = acc.get_allocator(); - DeviceSpecificArg per_device_state = + DeviceSpecificArg per_device_state = acc.create_device_specific( - init_kernel(handle, - allocator, - a_seq_length_dim, - b_seq_length_dim)); + init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); // assert(weight.shape.get_volume() * sizeof(float) == // acc.unwrap(per_device_state)->weightSize); @@ -107,7 +105,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); - FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); int m = b_input.shape[legion_dim_t(0)]; assert(m == output.shape[legion_dim_t(0)]); @@ -120,7 +119,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); i++) { //get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.get_dim(); + i++) { // get_dim() or get_volume()? int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -128,18 +128,18 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { } return profile(forward_kernel, - profiling, - "[BatchMatmul] forward_time = %.2lfms\n", - per_device_state, - output.get_float_ptr(), - a_input.get_float_ptr(), - b_input.get_float_ptr(), - nullptr, //c_ptr - m, - n, - k, - batch, - iter_config.seq_length); + profiling, + "[BatchMatmul] forward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + a_input.get_float_ptr(), + b_input.get_float_ptr(), + nullptr, // c_ptr + m, + n, + k, + batch, + iter_config.seq_length); } static void forward_task(Task const *task, @@ -156,14 +156,15 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(task->regions.size() == 6); // BatchMatmul* bmm = (BatchMatmul*) task->args; - FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); // is this equivalent to checking `Domain` equality? - assert(output == output_grad); + assert(output == output_grad); auto a_input = acc.get_tensor(A_INPUT); auto a_input_grad = acc.get_tensor_grad(A_INPUT); @@ -183,7 +184,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); i++) { //@colin get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.get_dim(); + i++) { //@colin get_dim() or get_volume()? int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -196,19 +198,19 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); return profile(backward_kernel, - profiling, - "[BatchMatmul] backward_time = %.2lfms\n", - per_device_state, - output.get_float_ptr(), - output_grad.get_float_ptr(), - a_input.get_float_ptr(), - a_input_grad.get_float_ptr(), - b_input.get_float_ptr(), - b_input_grad.get_float_ptr(), - m, - n, - k, - batch); + profiling, + "[BatchMatmul] backward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + a_input.get_float_ptr(), + a_input_grad.get_float_ptr(), + b_input.get_float_ptr(), + b_input_grad.get_float_ptr(), + m, + n, + k, + batch); } static void backward_task(Task const *task, @@ -220,15 +222,17 @@ static void backward_task(Task const *task, } CostMetrics measure_operator_cost(SimEnvFactory const &sim, - BatchMatmulAttrs const &attrs, - InputParallelTensorDesc const &a_input, - InputParallelTensorDesc const &b_input, - ProfilingSettings const &settings, - MachineView const &pc) { + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc) { auto env = sim.new_environment(); - ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); - ParallelTensorShape weight_shape = get_weights_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape output_shape = + get_output_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape weight_shape = + get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); @@ -249,8 +253,10 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - auto fwd_accessor = env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); - auto bwd_accessor = env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); + auto fwd_accessor = + env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); float forward_time = forward_task_impl(fwd_accessor).value(); float backward_time = backward_task_impl(bwd_accessor).value(); diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index ab9bb45e8d..42c9bc23de 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -24,11 +24,11 @@ OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim, - BatchMatmulAttrs const &attrs, - InputParallelTensorDesc const &a_input, - InputParallelTensorDesc const &b_input, - ProfilingSettings const &settings, - MachineView const &pc); + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc); /* class BatchMatmul : public Op { */ /* public: */ @@ -89,8 +89,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, #endif - - // BatchMatmulParams BatchMatmul::get_params() const { // BatchMatmulParams params; // params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; @@ -242,15 +240,14 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); // } - // void BatchMatmul::forward(FFModel const &ff) { // int dim = outputs[0]->num_dims; // switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ +// #define DIMFUNC(DIM) \ +// case DIM: { \ // // forward_with_dim(ff); -// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); -// break; +// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, +// get_fwd_task_signature()); break; // } // LEGION_FOREACH_N(DIMFUNC) // #undef DIMFUNC @@ -299,15 +296,14 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, output = A * B /////////+ C */ - // void BatchMatmul::init(FFModel const &ff) { // int dim = outputs[0]->num_dims; // switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ +// #define DIMFUNC(DIM) \ +// case DIM: { \ // // init_with_dim(ff); -// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); -// break; +// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, +// get_init_task_signature()); break; // } // LEGION_FOREACH_N(DIMFUNC) // #undef DIMFUNC @@ -365,7 +361,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return binding; // } - // static OpTaskSignature get_fwd_task_signature() { // OpTaskSignature fwd(OpTaskType::FWD); @@ -409,28 +404,25 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return binding; // } -//void BatchMatmul::backward(FFModel const &ff) { -// int dim = outputs[0]->num_dims; -// switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ -// backward_with_dim(ff); \ -// break; \ +// void BatchMatmul::backward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #d ef ine DIMFUNC(DIM) \ +// case DIM: { \ +// backward_with_dim(ff); \ +// break; \ // } -// LEGION_FOREACH_N(DIMFUNC) -// #undef DIMFUNC -// default: -// assert(false); -// } -// } - +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } // void BatchMatmul::print_layer(FFModel const &ff) { // return; // } - - /* regions[0](I): output regions[1](I): output_grad @@ -508,4 +500,4 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, regions[4](I): B regions[5](I/O): B_grad regions[6](I/O): C_grad -*/ \ No newline at end of file +*/ diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index bc4f38e5fb..07316c877b 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -6,10 +6,13 @@ #include "op-attrs/parallel_tensor_shape.h" #include "runtime/config.h" - namespace FlexFlow { -enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE, ITERATION_CONFIG }; +enum class OpArgRefType { + PER_DEVICE_OP_STATE, + PARALLEL_TENSOR_SHAPE, + ITERATION_CONFIG +}; template using OpArgRef = ArgRef; From 5458923bc344d2cefbd6bc18b7c3719e73291ba8 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 15:23:36 -0700 Subject: [PATCH 06/19] fix DeviceSpecific --- lib/runtime/src/ops/batch_matmul.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 0f4ad9e865..59570f4417 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -70,14 +70,14 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static DeviceSpecificArg +static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = acc.get_allocator(); - DeviceSpecificArg per_device_state = + DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); @@ -86,7 +86,7 @@ static DeviceSpecificArg return per_device_state; } -static DeviceSpecificArg +static DeviceSpecific init_task(Task const *task, std::vector const ®ions, Context ctx, @@ -241,7 +241,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, auto init_accessor = env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, init_binding); - DeviceSpecificArg per_device_state = + DeviceSpecific per_device_state = init_task_impl(init_accessor); SimTaskBinding fwd_binding; From 2f604dca90edf8ecaef4dade5bbb6d0fc4ed612f Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 25 Aug 2023 18:57:40 -0700 Subject: [PATCH 07/19] batch_norm --- .../include/kernels/batch_norm_kernels.h | 70 +- lib/runtime/src/ops/batch_norm.cc | 700 ++++++------------ lib/runtime/src/ops/batch_norm.h | 242 +++++- 3 files changed, 537 insertions(+), 475 deletions(-) diff --git a/lib/kernels/include/kernels/batch_norm_kernels.h b/lib/kernels/include/kernels/batch_norm_kernels.h index 6ff90299db..74dfc96068 100644 --- a/lib/kernels/include/kernels/batch_norm_kernels.h +++ b/lib/kernels/include/kernels/batch_norm_kernels.h @@ -8,30 +8,66 @@ namespace FlexFlow { -class BatchNormPerDeviceState : public PerDeviceOpState { -public: - BatchNormPerDeviceState(FFHandler handle, - std::unique_ptr allocator, - int output_n, - int output_c, - int output_h, - int output_w, - bool relu, - bool profiling); - ~BatchNormPerDeviceState(void); - - ffTensorDescriptor_t inputTensor, outputTensor, biasTensor; +struct BatchNormPerDeviceState { + PerDeviceFFHandle handle; + Allocator allocator; + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffTensorDescriptor_t biasTensor; ffActivationDescriptor_t actiDesc; ffBatchNormMode_t mode; - float *runningMean, *runningVar, *saveMean, *saveVar; - bool relu; - bool profiling; - std::unique_ptr allocator; + float *runningMean; + float *runningVar; + float *saveMean; + float *saveVar; + int output_n; + int output_c; + int output_h; + int output_w; + ProfilingSettings profiling; + req relu; }; +FF_VISITABLE_STRUCT_NO_EQ(BatchNormPerDeviceState, + handle, + allocator, + inputTensor, + outputTensor, + biasTensor, + actiDesc, + mode, + runningMean, + runningVar, + saveMean, + saveVar, + output_n, + output_c, + output_h, + output_w, + profiling, + relu); + namespace Kernels { namespace BatchNorm { +BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle, + Allocator allocator, + ffTensorDescriptor_t inputTensor, + ffTensorDescriptor_t outputTensor, + ffTensorDescriptor_t biasTensor, + ffActivationDescriptor_t actiDesc, + ffBatchNormMode_t mode, + float *runningMean, + float *runningVar, + float *saveMean, + float *saveVar, + int output_n, + int output_c, + int output_h, + int output_w, + ProfilingSettings profiling, + bool relu); + void forward_kernel(ffStream_t stream, BatchNormPerDeviceState *m, float const *input_ptr, diff --git a/lib/runtime/src/ops/batch_norm.cc b/lib/runtime/src/ops/batch_norm.cc index 98cc4576a1..ffd52c96fb 100644 --- a/lib/runtime/src/ops/batch_norm.cc +++ b/lib/runtime/src/ops/batch_norm.cc @@ -16,505 +16,291 @@ #include "batch_norm.h" #include "kernels/batch_norm_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" + +namespace FlexFlow { using namespace FlexFlow::Kernels::BatchNorm; -namespace FlexFlow { +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; enum Slots { - INPUT, - SCALE, - BIAS, - OUTPUT, - INPUT_GRAD, - SCALE_GRAD, - BIAS_GRAD, - OUTPUT_GRAD, + INPUT, // tensor + SCALE, // tensor + BIAS, // tensor + OUTPUT, // tensor ATTRS, - PROFILING -} - -Tensor - FFModel::batch_norm(const Tensor input, bool relu, char const *name) { - assert(input->num_dims == 4); /*NCHW*/ - Layer *bm = new Layer(this, - OP_BATCHNORM, - DT_FLOAT, - name, - 1 /*inputs*/, - 2 /*weights*/, - 1 /*outputs*/, - input); - int numdims = 4; - bm->outputs[0] = create_tensor_legion_ordering( - numdims, input->dims, DT_FLOAT, bm, 0, true /*create_grad*/); - bm->add_int_property("relu", relu); - layers.push_back(bm); - return bm->outputs[0]; -} - -/* - locals[0] = scale - locals[1] = bias -*/ -BatchNorm::BatchNorm(FFModel &model, - const ParallelTensor _input, - const ParallelTensor _scale, - const ParallelTensor _bias, - bool _relu, - char const *name) - : Op(model, - OP_BATCHNORM, - DT_FLOAT, - name, - 1 /*inputs*/, - 2 /*weights*/, - 1 /*outputs*/, - _input, - _scale, - _bias), - relu(_relu) { - assert(_input->num_dims == 4); - numOutputs = 1; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < _input->num_dims; i++) { - dims[i] = _input->dims[_input->num_dims - 1 - i]; - } - outputs[0] = - model.create_parallel_tensor(_input->num_dims, dims, DT_FLOAT, this); - return; -} - -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - - // init.add_input_slot(INPUT); - // init.add_param_slot(SCALE); - // init.add_param_slot(BIAS); - init.add_output_slot(OUTPUT); -} + PROFILING, + PER_DEVICE_STATE, + RELU, + HANDLE +}; -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_param_slot(SCALE); - fwd.add_param_slot(BIAS); - fwd.add_output_slot(OUTPUT, WRITE_DISCARD); - - return fwd; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_arg_slot(ATTRS); - - bwd.add_input_slot(INPUT); - bwd.add_input_grad_slot(INPUT_GRAD, READ_WRITE); - bwd.add_param_slot(SCALE); - bwd.add_param_grad_slot(SCALE_GRAD, READ_WRITE); - bwd.add_param_grad_slot(BIAS_GRAD, READ_WRITE); - bwd.add_output_grad_slot(OUTPUT_GRAD); - - return bwd; -} - -OpTaskBinding BatchNorm::get_init_task_binding() const { +OpTaskInvocation init(BatchNormAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); - - // binding.bind(INPUT, input_tensor(0)); - // binding.bind(SCALE, param_tensor(0)); - // binding.bind(BIAS, param_tensor(1)); + binding.bind(INPUT, input_tensor(0)); + binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + binding.bind_arg(ATTRS, attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(HANDLE, ff_handle()); + + return {BATCHNORM_INIT_TASK_ID, binding}; } -OpTaskBinding BatchNorm::get_fwd_task_binding() const { +OpTaskInvocation forward(BatchNormAttrs const &attrs) { OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); binding.bind(INPUT, input_tensor(0)); - binding.bind(SCALE, param_tensor(0)); - binding.bind(BIAS, param_tensor(1)); + binding.bind(SCALE, input_tensor(1)); + binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {BATCHNORM_FWD_TASK_ID, binding}; } -OpTaskBinding BatchNorm::get_bwd_task_binding() const { - OpTaskBinding binding; +OpTaskInvocation backward(BatchNormAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - binding.bind_arg(ATTRS, this->attrs); + return {BATCHNORM_BWD_TASK_ID, binding}; +} - binding.bind(INPUT, input_tensor(0)); - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(SCALE, param_tensor(0)); - binding.bind(SCALE_GRAD, param_tensor(0).grad()); - binding.bind(BIAS_GRAD, param_tensor(1).grad()); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + Allocator allocator = acc.get_allocator(); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto output = acc.get_tensor(OUTPUT); + auto const &attrs = acc.get_argument(ATTRS); + + int output_w = output.shape[legion_dim_t(0)]; + int output_h = output.shape[legion_dim_t(1)]; + int output_c = output.shape[legion_dim_t(2)]; + int output_n = output.shape[legion_dim_t(3)]; + + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffTensorDescriptor_t biasTensor; + ffActivationDescriptor_t actiDesc; + ffBatchNormMode_t mode; + + size_t totalSize = sizeof(float) * output_c * 4; + float *runningMean = (float *)allocator.allocate(totalSize); + float *runningVar = (float *)runningMean + output_c; + float *saveMean = (float *)runningVar + output_c; + float *saveVar = (float *)saveMean + output_c; + + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, + allocator, + inputTensor, + outputTensor, + biasTensor, + actiDesc, + mode, + runningMean, + runningVar, + saveMean, + saveVar, + output_n, + output_c, + output_h, + output_w, + profiling, + attrs.relu)); + + return per_device_state; +} - return binding; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -void BatchNorm::init(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(BATCHNORM_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(BatchNorm)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[1]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[1]->region)); - // launcher.add_field(3, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + assert(regions.size() == 4); + assert(task->regions.size() == 4); + + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto scale = acc.get_tensor(SCALE); + auto bias = acc.get_tensor(SCALE); + + return profile(forward_kernel, + profiling, + "[BatchNorm] forward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output.get_float_ptr(), + scale.get_float_ptr(), + bias.get_float_ptr()); } -/* - regions[0]: input - regions[1]: output - regions[2](I): scale - regions[3](I): bias -*/ -PerDeviceOpState * - BatchNorm::init_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); TaskArgumentAccessor acc(task, regions, ctx, runtime); - FFHandler handle = *((FFHandler const *)task->local_args); - - auto output = acc.get_tensor(OUTPUT); - - int output_w = output.shape[0]; - int output_h = output.shape[1]; - int output_c = output.shape[2]; - int output_n = output.shape[3]; - - Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) - .only_kind(Memory::GPU_FB_MEM) - .best_affinity_to(task->target_proc) - .first(); - BatchNormPerDeviceState *m = new BatchNormPerDeviceState( - handle, bm, gpu_mem, output_n, output_c, output_h, output_w); - return m; + forward_task_impl(acc); } -void BatchNorm::forward(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(BATCHNORM_FWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, 0), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_DISCARD, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[1]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[1]->region)); - // launcher.add_field(3, FID_DATA); - - // runtime->execute_index_space(ctx, launcher); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + assert(regions.size() == 7); + assert(task->regions.size() == 7); + + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + auto scale = acc.get_tensor(SCALE); + auto scale_grad = acc.get_tensor_grad(SCALE); + auto bias_grad = acc.get_tensor_grad(BIAS); + + return profile(backward_kernel, + profiling, + "[BatchNorm] backward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output_grad.get_float_ptr(), + output.get_float_ptr(), + input_grad.get_float_ptr(), + scale.get_float_ptr(), + scale_grad.get_float_ptr(), + bias_grad.get_float_ptr(), + output.shape.get_volume()); } -/* - regions[0](I): input - regions[1](O): ouptut - regions[2](I): scale - regions[3](I): bias -*/ -void BatchNorm::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - // const BatchNorm* bm = (BatchNorm*) task->args; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - BatchNormPerDeviceState *m = *((BatchNormPerDeviceState **)task->local_args); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto scale = acc.get_tensor(SCALE); - auto bias = acc.get_tensor(SCALE); - - profile(forward_kernel, - m->profiling, - "[BatchNorm] forward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output.get_float_ptr(), - scale.get_float_ptr(), - bias.get_float_ptr()); + backward_task_impl(acc); } -void BatchNorm::backward(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(BATCHNORM_BWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, 0), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // // regions[0](I): input - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // // regions[1](I/O): input_grad (we only need grad tensors) - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // inputs[0]->region_grad)); - // launcher.add_field(1, FID_DATA); - // // regions[2](I): output - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(2, FID_DATA); - // // regions[3](I/O): output_grad - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(3, FID_DATA); - // // regions[4](I): filter - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(4, FID_DATA); - // // regions[5](I/O): filter_grad - // launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // weights[0]->region_grad)); - // launcher.add_field(5, FID_DATA); - // // regions[6](I/O): bias_grad - // launcher.add_region_requirement(RegionRequirement(weights[1]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // weights[1]->region_grad)); - // launcher.add_field(6, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchNormAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + InputParallelTensorDesc const &scale_shape, + InputParallelTensorDesc const &bias_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + + // int output_w = sub_output.dims[0].size; + // int output_h = sub_output.dims[1].size; + // int output_c = sub_output.dims[2].size; + // int output_n = sub_output.dims[3].size; + // BatchNormPerDeviceState *m = new BatchNormPerDeviceState( + // sim->handler, this, sim->memory, output_n, output_c, output_h, + // output_w); + + // sim->free_all(); + // float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), + // DT_FLOAT); assert(input_ptr != NULL); cost_metrics.inputs_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + // float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), + // DT_FLOAT); assert(output_ptr != NULL); cost_metrics.outputs_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + // float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); + // assert(bias_ptr != NULL); + // float *scale_ptr = (float *)sim->allocate(output_c, DT_FLOAT); + // assert(scale_ptr != NULL); + // cost_metrics.weights_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + auto env = sim.new_environment(); + + ParallelTensorShape output_shape = get_output_shape(attrs); + + SimTaskBinding init_binding; + init_binding.bind(INPUT, input_shape); + init_binding.bind(BIAS, bias_shape); + init_binding.bind(OUTPUT, output_shape); + + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(PROFILING, settings); + init_binding.bind_arg(HANDLE, ff_handle()); + + auto init_accessor = + env.get_init_accessor(ATTENTION_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(SCALE, scale_shape); + fwd_binding.bind(BIAS, bias_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = env.get_fwd_accessor(ATTENTION_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(ATTENTION_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -/* - regions[0](I): input - regions[1](I/O): input_grad - regions[2](I): output - regions[3](I/O): output_grad - regions[4](I): scale - regions[5](I/O): scale_grad - regions[6](I/O): bias_grad -*/ -__host__ void - BatchNorm::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 7); - assert(task->regions.size() == 7); - // float beta = 0.0f; - // const BatchNorm* bm = (BatchNorm*) task->args; - TaskArgumentAccessor acc(task, regions, ctx, runtime); - BatchNormPerDeviceState *m = *((BatchNormPerDeviceState **)task->local_args); - - auto input = acc.get_tensor(INPUT); - auto input_grad = acc.get_tensor_grad(INPUT_GRAD); - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT_GRAD); - auto scale = acc.get_tensor(SCALE); - auto scale_grad = acc.get_tensor_grad(SCALE_GRAD); - auto bias_grad = acc.get_tensor_grad(BIAS_GRAD); - - profile(backward_kernel, - m->profiling, - "[BatchNorm] backward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output_grad.get_float_ptr(), - output.get_float_ptr(), - input_grad.get_float_ptr(), - scale.get_float_ptr(), - scale_grad.get_float_ptr(), - bias_grad.get_float_ptr(), - output.get_volume()); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + init.add_input_slot(INPUT); + init.add_input_slot(BIAS); + init.add_output_slot(OUTPUT); + init.add_arg_slot(ATTRS); + init.add_arg_slot(PROFILING); + init.add_unchecked_arg_slot(HANDLE); + + register_task(BATCHNORM_INIT_TASK_ID, "BatchNorm Init", init, init_task); } -bool BatchNorm::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_input, sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - - int output_w = sub_output.dims[0].size; - int output_h = sub_output.dims[1].size; - int output_c = sub_output.dims[2].size; - int output_n = sub_output.dims[3].size; - BatchNormPerDeviceState *m = new BatchNormPerDeviceState( - sim->handler, this, sim->memory, output_n, output_c, output_h, output_w); - - sim->free_all(); - float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_ptr != NULL); - float *scale_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(scale_ptr != NULL); - cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset); - - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, m, input_ptr, output_ptr, scale_ptr, bias_ptr); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - float *input_grad_ptr = - (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_grad_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - float *scale_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(scale_grad_ptr != NULL); - float *bias_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_grad_ptr != NULL); - cost_metrics.weights_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - m, - input_ptr, - output_grad_ptr, - output_ptr, - input_grad_ptr, - scale_ptr, - scale_grad_ptr, - bias_grad_ptr, - sub_output.get_volume()); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchNorm] name(%s) size(%zu) forward_time(%.4lf) " - "backward_time(%.4lf)\n", - name, - sub_input.get_volume(), - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchNorm] name(%s) size(%zu) forward_time(%.4lf)\n", - name, - sub_input.get_volume(), - cost_metrics.forward_time); - } - // Free batchnormmeta - delete m; - return true; +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_input_slot(INPUT); + fwd.add_input_slot(SCALE); + fwd.add_input_slot(BIAS); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + + register_task(BATCHNORM_FWD_TASK_ID, "BatchNorm Fwd", fwd, forward_task); +} + +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(BATCHNORM_FWD_TASK_ID)); + + register_task(BATCHNORM_BWD_TASK_ID, "BatchNorm Bwd", bwd, backward_task); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_norm.h b/lib/runtime/src/ops/batch_norm.h index e54331665e..94bda5122b 100644 --- a/lib/runtime/src/ops/batch_norm.h +++ b/lib/runtime/src/ops/batch_norm.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_BATCH_NORM_H #include "op-attrs/ops/batch_norm.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -66,3 +66,243 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + +// void BatchNorm::init(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); +// } + +// void BatchNorm::forward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); +// } + +// void BatchNorm::backward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); +// } + +// Tensor batch_norm(const Tensor input, bool relu, char const *name) { +// assert(input->num_dims == 4); /*NCHW*/ +// Layer *bm = new Layer(this, +// OP_BATCHNORM, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// 2 /*weights*/, +// 1 /*outputs*/, +// input); +// int numdims = 4; +// bm->outputs[0] = create_tensor_legion_ordering( +// numdims, input->dims, DT_FLOAT, bm, 0, true /*create_grad*/); +// bm->add_int_property("relu", relu); +// layers.push_back(bm); +// return bm->outputs[0]; +// } + +// BatchNorm::BatchNorm(FFModel &model, +// const ParallelTensor _input, +// const ParallelTensor _scale, +// const ParallelTensor _bias, +// bool _relu, +// char const *name) +// : Op(model, +// OP_BATCHNORM, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// 2 /*weights*/, +// 1 /*outputs*/, +// _input, +// _scale, +// _bias), +// relu(_relu) { +// assert(_input->num_dims == 4); +// numOutputs = 1; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < _input->num_dims; i++) { +// dims[i] = _input->dims[_input->num_dims - 1 - i]; +// } +// outputs[0] = +// model.create_parallel_tensor(_input->num_dims, dims, DT_FLOAT, this); +// return; +// } + +/* + locals[0] = scale + locals[1] = bias +*/ + +// void BatchNorm::init(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); +// assert(check_output_input_weight_same_parallel_is()); +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_init(ff, argmap); +// IndexLauncher launcher(BATCHNORM_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(BatchNorm)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[1]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[1]->region)); +// launcher.add_field(3, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// set_opmeta_from_futuremap(ff, fm); +// } + +/* + regions[0]: input + regions[1]: output + regions[2](I): scale + regions[3](I): bias +*/ + +// void BatchNorm::forward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher(BATCHNORM_FWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, 0), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_DISCARD, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[1]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[1]->region)); +// launcher.add_field(3, FID_DATA); + +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): input + regions[1](O): ouptut + regions[2](I): scale + regions[3](I): bias +*/ + +// void BatchNorm::backward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher(BATCHNORM_BWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, 0), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// // regions[0](I): input +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// // regions[1](I/O): input_grad (we only need grad tensors) +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// // regions[2](I): output +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(2, FID_DATA); +// // regions[3](I/O): output_grad +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // regions[4](I): filter +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(4, FID_DATA); +// // regions[5](I/O): filter_grad +// launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// weights[0]->region_grad)); +// launcher.add_field(5, FID_DATA); +// // regions[6](I/O): bias_grad +// launcher.add_region_requirement(RegionRequirement(weights[1]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// weights[1]->region_grad)); +// launcher.add_field(6, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): input + regions[1](I/O): input_grad + regions[2](I): output + regions[3](I/O): output_grad + regions[4](I): scale + regions[5](I/O): scale_grad + regions[6](I/O): bias_grad +*/ From f2205f43b18ca97e332c11d56668e1c0516c6826 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:26:59 -0700 Subject: [PATCH 08/19] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 23522a89f7..6850dec178 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -19,8 +19,8 @@ FF_VISITABLE_STRUCT_NO_EQ( namespace Kernels { namespace BatchMatmul { -BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, - Allocator allocator, +BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator const &allocator, int a_seq_length_dim, int b_seq_length_dim); From 2f4662da4fcb64dad9de7dd6294e8e6baa3e532e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:28:45 -0700 Subject: [PATCH 09/19] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 6850dec178..e5062b2c61 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -26,7 +26,7 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, - float *o_ptr, + float *output_ptr, float const *a_ptr, float const *b_ptr, float const *c_ptr, From 642eb90d99bdb83b3a212be50232520b497816ba Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:34:04 -0700 Subject: [PATCH 10/19] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 5 ++--- lib/runtime/src/ops/batch_matmul.cc | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index e5062b2c61..ec32648d0f 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -27,9 +27,8 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, float *output_ptr, - float const *a_ptr, - float const *b_ptr, - float const *c_ptr, + float const *lhs_input_ptr, + float const *rhs_input_ptr, int m, int n, int k, diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 59570f4417..09eb1f23c7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -134,7 +134,6 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { output.get_float_ptr(), a_input.get_float_ptr(), b_input.get_float_ptr(), - nullptr, // c_ptr m, n, k, From e79406a842117e7823d8d63ed21ad5613c846078 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:46:00 -0700 Subject: [PATCH 11/19] change --- lib/kernels/src/cuda/batch_matmul_kernels.cu | 5 +---- lib/kernels/src/hip/batch_matmul_kernels.cpp | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index 3593ac4ab2..cde0df93c0 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -18,9 +18,6 @@ namespace FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -124,7 +121,7 @@ O = A * B */ void forward_kernel(cudaStream_t stream, - BatchMatmulPerDeviceState const *meta, + BatchMatmulPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp b/lib/kernels/src/hip/batch_matmul_kernels.cpp index e5334b1841..a06442d3d6 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -29,7 +29,7 @@ O: (batch, n, m) O = A * B */ void forward_kernel(hipStream_t stream, - BatchMatmulPerDeviceState const *meta, + BatchMatmulPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, From bb3c10f8a9f46c9a09d514a9c70f4e72e37e3e94 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:51:31 -0700 Subject: [PATCH 12/19] change --- lib/op-attrs/include/op-attrs/ops/batch_matmul.h | 3 --- lib/runtime/src/ops/batch_matmul.cc | 6 +----- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 6ac2a2cf69..00c700ba20 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -14,9 +14,6 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); -ParallelTensorShape get_weights_shape(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 09eb1f23c7..669cad215b 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -80,9 +80,7 @@ static DeviceSpecific DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); - - // assert(weight.shape.get_volume() * sizeof(float) == - // acc.unwrap(per_device_state)->weightSize); + return per_device_state; } @@ -230,8 +228,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); - ParallelTensorShape weight_shape = - get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); From 848850900f9d44fe29df90c69a2099cb71be299e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:00:34 -0700 Subject: [PATCH 13/19] change --- lib/runtime/include/runtime/config.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index 21cfca20cf..ef7e779469 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -107,7 +107,7 @@ struct FFConfig : public use_visitable_cmp { struct FFIterationConfig { FFIterationConfig(); void reset(); - req seq_length; + int seq_length; }; FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); From 36eba295424c5c58ea0e3aedea76e3f3adfd3790 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:03:04 -0700 Subject: [PATCH 14/19] change --- lib/runtime/src/ops/batch_matmul.h | 55 ------------------------------ 1 file changed, 55 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index 42c9bc23de..018fe1d582 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -30,61 +30,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, ProfilingSettings const &settings, MachineView const &pc); -/* class BatchMatmul : public Op { */ -/* public: */ -/* BatchMatmul(FFModel &model, */ -/* const ParallelTensor A, */ -/* const ParallelTensor B, */ -/* int a_seq_length_dim, */ -/* int b_seq_length_dim, */ -/* char const *name = nullptr); */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* /1* static PCG::Node deserialize(FFModel &ff, *1/ */ -/* /1* Legion::Deserializer &d, *1/ */ -/* /1* ParallelTensor inputs[], *1/ */ -/* /1* int num_inputs); *1/ */ -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* OpTaskBinding get_init_task_binding() const override; */ -/* OpTaskBinding get_fwd_task_binding() const override; */ -/* OpTaskBinding get_bwd_task_binding() const override; */ -/* private: */ -/* template */ -/* void init_with_dim(FFModel const &ff); */ -/* template */ -/* void forward_with_dim(FFModel const &ff); */ -/* template */ -/* void backward_with_dim(FFModel const &ff); */ - -/* public: */ -/* int a_seq_length_dim, b_seq_length_dim; */ -/* }; */ - } // namespace FlexFlow #endif From 98773b4237de4a4585711bca14d26b9d3c6e75ff Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:15:53 -0700 Subject: [PATCH 15/19] change --- lib/runtime/src/ops/batch_matmul.cc | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 669cad215b..a8c2ec7bd7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -72,8 +72,8 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { - auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); - auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); + int const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); + int const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = acc.get_allocator(); @@ -94,9 +94,6 @@ static DeviceSpecific } static optional forward_task_impl(TaskArgumentAccessor const &acc) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - auto a_input = acc.get_tensor(A_INPUT); auto b_input = acc.get_tensor(B_INPUT); auto output = acc.get_tensor(OUTPUT); @@ -148,10 +145,6 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - // Currently assume C is NULL - assert(regions.size() == 6); - assert(task->regions.size() == 6); - // BatchMatmul* bmm = (BatchMatmul*) task->args; FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); @@ -161,15 +154,15 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); // is this equivalent to checking `Domain` equality? - assert(output == output_grad); + assert(output.shape == output_grad.shape); auto a_input = acc.get_tensor(A_INPUT); auto a_input_grad = acc.get_tensor_grad(A_INPUT); - assert(a_input == a_input_grad); + assert(a_input.shape == a_input_grad.shape); auto b_input = acc.get_tensor(B_INPUT); auto b_input_grad = acc.get_tensor_grad(B_INPUT); - assert(b_input == b_input_grad); + assert(b_input.shape == b_input_grad.shape); // check dins int m = b_input.shape[legion_dim_t(0)]; @@ -181,8 +174,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); - i++) { //@colin get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.dims.num_dims(); + i++) { int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); From ae592613cf3fa6c763ca0f14b30a1899831ee4a9 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:28:18 -0700 Subject: [PATCH 16/19] change --- lib/runtime/src/task_spec/op_arg_ref.h | 10 ++-------- lib/runtime/src/task_spec/runtime_arg_ref.h | 2 ++ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index f2361af05d..20ce6892a2 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -4,15 +4,13 @@ #include "arg_ref.h" #include "device_specific.h" #include "op-attrs/parallel_tensor_shape.h" -#include "runtime/config.h" namespace FlexFlow { enum class OpArgRefType { PER_DEVICE_OP_STATE, - PARALLEL_TENSOR_SHAPE, - ITERATION_CONFIG -}; + PARALLEL_TENSOR_SHAPE + }; template using OpArgRef = ArgRef; @@ -28,10 +26,6 @@ OpArgRef input_parallel_tensor_shape(int idx) { return {OpArgRefType::PARALLEL_TENSOR_SHAPE}; } -OpArgRef iteration_config() { - return {OpArgRefType::ITERATION_CONFIG}; -} - } // namespace FlexFlow #endif diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.h b/lib/runtime/src/task_spec/runtime_arg_ref.h index 6b4345091a..033c2bcfbc 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.h +++ b/lib/runtime/src/task_spec/runtime_arg_ref.h @@ -3,6 +3,7 @@ #include "arg_ref.h" #include "device_specific.h" +#include "runtime/config.h" namespace FlexFlow { @@ -15,6 +16,7 @@ using RuntimeArgRefSpec = ArgRefSpec; RuntimeArgRef profiling_settings(); RuntimeArgRef> ff_handle(); +RuntimeArgRef iteration_config(); } // namespace FlexFlow From 8db251252670a8107430318da166e2ba059879ca Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 17:50:28 -0700 Subject: [PATCH 17/19] fix asserts --- lib/runtime/src/ops/attention.cc | 15 --------------- lib/runtime/src/ops/batch_matmul.cc | 13 ++++--------- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/lib/runtime/src/ops/attention.cc b/lib/runtime/src/ops/attention.cc index bca87bdb53..94e2b03731 100644 --- a/lib/runtime/src/ops/attention.cc +++ b/lib/runtime/src/ops/attention.cc @@ -121,18 +121,6 @@ static DeviceSpecific int num_samples = get_piece_shape(query_parallel_tensor_shape)[ff_dim_t(2)]; int num_heads = get_piece_shape(weight_parallel_tensor_shape)[ff_dim_t(1)]; - assert(qoSeqLength == query.shape[legion_dim_t(1)]); - assert(qSize == query.shape[legion_dim_t(0)]); - assert(num_samples == key.shape[legion_dim_t(2)]); - assert(kvSeqLength == key.shape[legion_dim_t(1)]); - assert(kSize == key.shape[legion_dim_t(0)]); - assert(num_samples == value.shape[legion_dim_t(2)]); - assert(kvSeqLength == value.shape[legion_dim_t(1)]); - assert(vSize == value.shape[legion_dim_t(0)]); - assert(num_samples == output.shape[legion_dim_t(2)]); - assert(qoSeqLength == output.shape[legion_dim_t(1)]); - assert(oProjSize == output.shape[legion_dim_t(0)]); - DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, @@ -149,9 +137,6 @@ static DeviceSpecific qoSeqLength, kvSeqLength, attrs.add_bias_kv)); - - assert(weight.shape.get_volume() * sizeof(float) == - acc.unwrap(per_device_state)->weightSize); return per_device_state; } diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index a8c2ec7bd7..00699652e7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -110,8 +110,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { int k = a_input.shape[legion_dim_t(0)]; assert(k == b_input.shape[legion_dim_t(1)]); - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; for (int i = 2; i < a_input.shape.get_dim(); @@ -171,8 +171,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(n == output.shape[legion_dim_t(1)]); int k = a_input.shape[legion_dim_t(0)]; assert(k == b_input.shape[legion_dim_t(1)]); - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; for (int i = 2; i < a_input.shape.dims.num_dims(); i++) { @@ -182,11 +182,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { batch *= dim_size; } - // TODO: add support for meta->a_seq_length_dim >= 0 - // or meta->b_seq_length_dim >= 0 - assert((meta->a_seq_length_dim >= a_len) || (iter_config.seq_length == 0)); - assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); - return profile(backward_kernel, profiling, "[BatchMatmul] backward_time = %.2lfms\n", From e8a6c30a439fcb1168986939de909d6c717c66f5 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 17:57:44 -0700 Subject: [PATCH 18/19] remove asserts --- lib/runtime/src/ops/batch_norm.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lib/runtime/src/ops/batch_norm.cc b/lib/runtime/src/ops/batch_norm.cc index ffd52c96fb..6ebf359051 100644 --- a/lib/runtime/src/ops/batch_norm.cc +++ b/lib/runtime/src/ops/batch_norm.cc @@ -130,9 +130,6 @@ static DeviceSpecific } static optional forward_task_impl(TaskArgumentAccessor const &acc) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -161,9 +158,6 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - assert(regions.size() == 7); - assert(task->regions.size() == 7); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); From 09acbe545ad4d6d1cab1e14e55bf1d7e9a638954 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 17:59:10 -0700 Subject: [PATCH 19/19] format --- lib/runtime/src/ops/batch_matmul.cc | 5 ++--- lib/runtime/src/task_spec/op_arg_ref.h | 5 +---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 00699652e7..45c5e11b9c 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -80,7 +80,7 @@ static DeviceSpecific DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); - + return per_device_state; } @@ -174,8 +174,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.get_volume() == b_input.shape.get_volume()); assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; - for (int i = 2; i < a_input.shape.dims.num_dims(); - i++) { + for (int i = 2; i < a_input.shape.dims.num_dims(); i++) { int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index 20ce6892a2..3e931d79a4 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -7,10 +7,7 @@ namespace FlexFlow { -enum class OpArgRefType { - PER_DEVICE_OP_STATE, - PARALLEL_TENSOR_SHAPE - }; +enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE }; template using OpArgRef = ArgRef;