From 6b248cf20e560dbc2ff57d0db2bf463775177620 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 22 Aug 2023 14:11:42 -0700 Subject: [PATCH 01/24] batch matmul initial commit --- .../include/kernels/batch_matmul_kernels.h | 23 +- lib/kernels/src/hip/batch_matmul_kernels.cpp | 7 +- .../include/op-attrs/ops/batch_matmul.h | 4 +- lib/op-attrs/src/batch_matmul.cc | 8 + lib/runtime/src/ops/batch_matmul.cc | 855 ++++-------------- lib/runtime/src/ops/batch_matmul.h | 441 ++++++++- 6 files changed, 655 insertions(+), 683 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 0e4437bdb8..cdffdf1907 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -6,17 +6,26 @@ namespace FlexFlow { -class BatchMatmulPerDeviceState : public PerDeviceOpState { -public: - BatchMatmulPerDeviceState(FFHandler handler); - int a_seq_length_dim, b_seq_length_dim; +struct BMMPerDeviceState { + PerDeviceFFHandle handle; + Allocator allocator; + int a_seq_length_dim; + int b_seq_length_dim; }; +FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, + handle,); + namespace Kernels { namespace BatchMatmul { +BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, + Allocator allocator, + int a_seq_length_dim, + int b_seq_length_dim); + void forward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, + BMMPerDeviceState const *meta, float *o_ptr, float const *a_ptr, float const *b_ptr, @@ -25,12 +34,10 @@ void forward_kernel(ffStream_t stream, int n, int k, int batch, - int a_seq_length_dim = -1, - int b_seq_length_dim = -1, int seq_length = -1); void backward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, + BMMPerDeviceState const *meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp b/lib/kernels/src/hip/batch_matmul_kernels.cpp index 
d8b6500326..e5334b1841 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -19,9 +19,6 @@ namespace FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -42,9 +39,9 @@ void forward_kernel(hipStream_t stream, int k, int batch, hipStream_t stream, - int a_seq_length_dim, - int b_seq_length_dim, int seq_length) { + int a_seq_length_dim = meta->a_seq_length_dim; + int b_seq_length_dim = meta->b_seq_length_dim; checkCUDA(hipblasSetStream(meta->handle.blas, stream)); checkCUDNN(miopenSetStream(meta->handle.dnn, stream)); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index c74824570c..00c700ba20 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -12,8 +12,10 @@ struct BatchMatmulAttrs { }; FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); -CHECK_VALID_OP_ATTR(BatchMatmulAttrs); +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); +CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/batch_matmul.cc b/lib/op-attrs/src/batch_matmul.cc index 1cc8c5cfda..bd61c24737 100644 --- a/lib/op-attrs/src/batch_matmul.cc +++ b/lib/op-attrs/src/batch_matmul.cc @@ -2,6 +2,14 @@ namespace FlexFlow { +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs) { + return attrs.a_seq_length_dim; +} + +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs) { + return attrs.b_seq_length_dim; +} + /* bool BatchMatmulAttrs::is_valid( */ /* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { */ diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 3e860bd413..9bbc050f8b 100644 --- a/lib/runtime/src/ops/batch_matmul.cc 
+++ b/lib/runtime/src/ops/batch_matmul.cc @@ -13,754 +13,287 @@ * limitations under the License. */ +// #include "batch_matmul.h" +// #include "kernels/batch_matmul_kernels.h" +// #include "kernels/profiling.h" +// #include "legion/legion_utilities.h" +// #include "tasks.h" + #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" -#include "kernels/profiling.h" -#include "legion/legion_utilities.h" -#include "tasks.h" +#include "legion.h" +#include "op-attrs/ops/batch_matmul.h" namespace FlexFlow { using namespace FlexFlow::Kernels::BatchMatmul; +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; + enum Slots { - A_INPUT, - B_INPUT, - OUTPUT, - A_INPUT_GRAD, - B_INPUT_GRAD, - OUTPUT_GRAD, - ATTRS, - PROFILING -}; + A_INPUT, //tensor + B_INPUT, //tensor + OUTPUT, //tensor + PROFILING, + HANDLE, + A_SEQ_LENGTH_DIM, + B_SEQ_LENGTH_DIM, + PER_DEVICE_STATE + }; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; + OpTaskBinding init; - b.bind_arg(ATTRS, attrs); - b.bind_arg(PROFILING, enable_profiling()); + init.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init.bind_arg(HANDLE, ff_handle()); - return {BATCHMATMUL_INIT_TASK_ID, b}; + return {BATCHMATMUL_INIT_TASK_ID, init}; } OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; - - b.bind(A_INPUT, input_tensor(0)); - b.bind(B_INPUT, input_tensor(1)); - b.bind(OUTPUT, output_tensor(0)); - - return {BATCHMATMUL_FWD_TASK_ID, b}; -} - -OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return {BATCHMATMUL_BWD_TASK_ID, b}; -} - -BatchMatmulParams BatchMatmul::get_params() const { - BatchMatmulParams params; - params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; - params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; - return params; 
-} - -Tensor FFModel::batch_matmul(const Tensor A, - const Tensor B, - int a_seq_length_dim, - int b_seq_length_dim, - char const *name) { - Layer *bmm = new Layer(this, - OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B); - assert((a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - int dims[MAX_TENSOR_DIM]; - int numdim = A->num_dims; - for (int i = 0; i < numdim; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - bmm->outputs[0] = create_tensor_legion_ordering( - numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); - bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); - bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); - layers.push_back(bmm); - return bmm->outputs[0]; -} - -Op *BatchMatmul::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("a_seq_length_dim", value); - int a_seq_length_dim = value; - layer->get_int_property("b_seq_length_dim", value); - int b_seq_length_dim = value; - return new BatchMatmul(model, - inputs[0], - inputs[1], - a_seq_length_dim, - b_seq_length_dim, - layer->name); -} - -BatchMatmul::BatchMatmul( - FFModel &model, - BatchMatmulParams const ¶ms, - std::pair const &inputs, - char const *name) - : BatchMatmul(model, - inputs.first, - inputs.second, - params.a_seq_length_dim, - params.b_seq_length_dim, - name) {} - -// return A*B -BatchMatmul::BatchMatmul(FFModel &model, - const ParallelTensor A, - const ParallelTensor B, - int _a_seq_length_dim, - int _b_seq_length_dim, - char const *name) - : Op(model, - 
OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B), - a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), - b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { - assert((_a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((_b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < A->num_dims; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - numOutputs = 1; - outputs[0] = model.create_parallel_tensor_legion_ordering( - A->num_dims, dims, DT_FLOAT, this); - // C is not none - // if (C != Tensor::NO_TENSOR) { - // numInputs = 3; - // assert(C.num_dims == outputs[0].num_dims); - // for (int i = 0; i < C.num_dims; i++) - // assert(C.adim[i] == outputs[0].adim[i]); - //} -} - -void BatchMatmul::serialize(Legion::Serializer &sez) const { - BatchMatmulParams params = get_params(); - sez.serialize(params.a_seq_length_dim); - sez.serialize(params.b_seq_length_dim); -} - -using PCG::Node; -/*static*/ -Node BatchMatmul::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 2); - int a_seq_length_dim, b_seq_length_dim; - dez.deserialize(a_seq_length_dim); - dez.deserialize(b_seq_length_dim); - - BatchMatmulParams params; - params.a_seq_length_dim = a_seq_length_dim; - params.b_seq_length_dim = b_seq_length_dim; - return ff.get_or_create_node({inputs[0], inputs[1]}, params); -} + OpTaskBinding fwd; -Op *BatchMatmul::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - BatchMatmulParams params = get_params(); - return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, 
this->name); -} - -template <> -void register_task() { - OpTaskSignature sig(OpTaskType::INIT); + fwd.bind(A_INPUT, input_tensor(0)); + fwd.bind(B_INPUT, input_tensor(1)); + fwd.bind(OUTPUT, output_tensor(0)); - sig.add_arg_slot(ATTRS); - sig.add_arg_slot(PROFILING); + fwd.bind_arg(PROFILING, profiling_settings()); + fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", sig, init_task); + return {BATCHMATMUL_FWD_TASK_ID, fwd}; } -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(A_INPUT, READ_WRITE); - fwd.add_input_slot(B_INPUT, READ_WRITE); - fwd.add_output_slot(OUTPUT); - - return fwd; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_input_slot(A_INPUT); - bwd.add_input_slot(B_INPUT); - bwd.add_input_grad_slot(A_INPUT_GRAD); - bwd.add_input_grad_slot(B_INPUT_GRAD); - bwd.add_output_slot(OUTPUT); - bwd.add_output_grad_slot(OUTPUT_GRAD); - - return bwd; -} - -OpTaskBinding BatchMatmul::get_init_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); +OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { + OpTaskBinding bwd = infer_bwd_binding(forward(attrs).binding); - return binding; + return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -OpTaskBinding BatchMatmul::get_fwd_task_binding() const { - OpTaskBinding binding; - - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind(OUTPUT, output_tensor(0)); - - binding.bind_arg(ATTRS, this->attrs); - return binding; +static DeviceSpecificArg init_task_impl(TaskArgumentAccessor const &acc) { + auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); + auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + Allocator allocator = 
acc.get_allocator(); + + DeviceSpecificArg per_device_state = + acc.create_device_specific( + init_kernel(handle, + allocator, + a_seq_length_dim, + b_seq_length_dim)); + + // assert(weight.shape.get_volume() * sizeof(float) == + // acc.unwrap(per_device_state)->weightSize); + return per_device_state; } -OpTaskBinding BatchMatmul::get_bwd_task_binding() const { - OpTaskBinding binding; - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); - binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); - - binding.bind(OUTPUT, output_tensor(0)); - binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); - - binding.bind_arg(ATTRS, this->attrs); - return binding; -} - -void BatchMatmul::init(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // init_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} -} // namespace FlexFlow -// / -// template -// void BatchMatmul::init_with_dim(FFModel const &ff) { -// assert(check_output_input_weight_same_parallel_is()); -// parallel_is = outputs[0]->parallel_is; -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_init(ff, argmap); -// IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, -// parallel_is, -// TaskArgument(this, sizeof(BatchMatmul)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// 
READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// FutureMap fm = runtime->execute_index_space(ctx, launcher); -// fm.wait_all_results(); -// set_opmeta_from_futuremap(ff, fm); -// } - -PerDeviceOpState * - BatchMatmul::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static DeviceSpecificArg + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - - FFHandler handle = *((FFHandler const *)task->local_args); - BatchMatmulPerDeviceState *m = new BatchMatmulPerDeviceState(handle); - m->profiling = profiling; - m->a_seq_length_dim = attrs.a_seq_length_dim; - m->b_seq_length_dim = attrs.b_seq_length_dim; - return m; -} - -void BatchMatmul::forward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // forward_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} + return init_task_impl(acc); } -// template -// void BatchMatmul::forward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_forward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_FWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// 
launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](O): output - regions[1](I): A - regions[2](I): B - ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? - output = A * B /////////+ C -*/ -void BatchMatmul::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(regions.size() == 3); assert(task->regions.size() == 3); - TaskArgumentAccessor acc(task, regions, ctx, runtime); + auto a_input = acc.get_tensor(A_INPUT); + auto b_input = acc.get_tensor(B_INPUT); + auto output = acc.get_tensor(OUTPUT); - // const BatchMatmul* bmm = (const BatchMatmul*) task->args; + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - // BatchMatmulMeta const *meta = *((BatchMatmulMeta **)task->local_args); - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - auto a_input = acc.get_tensor(A_INPUT); - auto b_input = acc.get_tensor(B_INPUT); - auto output = acc.get_tensor(OUTPUT); - - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); + int batch = 1; for (int i = 
2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); batch *= dim_size; } - float *out_ptr = output.get_float_ptr(); - c float const *a_ptr = a_input.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float const *c_ptr = NULL; - // if (regions.size() == 4) { - // Domain c_domain = runtime->get_index_space_domain( - // ctx, task->regions[3].region.get_index_space()); - // assert(c_domain == a_domain); - // c_ptr = helperGetTensorPointerRO( - // regions[3], task->regions[3], FID_DATA, ctx, runtime); - // } - - profile(forward_kernel, - meta->profiling, + + return profile(forward_kernel, + profiling, "[BatchMatmul] forward_time = %.2lfms\n", - out_ptr, - a_ptr, - b_ptr, - c_ptr, + per_device_state, + output.get_float_ptr(), + a_input.get_float_ptr(), + b_input.get_float_ptr(), + NULL, //c_ptr m, n, k, batch, - meta->a_seq_length_dim, - meta->b_seq_length_dim, iter_config->seq_length); } -void BatchMatmul::backward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - backward_with_dim(ff); \ - break; \ - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); - } +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -// template -// void BatchMatmul::backward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// 
set_argumentmap_for_backward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_BWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// // regions[0](I): output -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// // regions[1](I): output_grad -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region_grad)); -// launcher.add_field(1, FID_DATA); -// // regions[2](I): A -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[0]->region)); -// launcher.add_field(2, FID_DATA); -// // regions[3](I/O): A_grad -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[0]->region_grad)); -// launcher.add_field(3, FID_DATA); -// // regions[4](I): B -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[1]->region)); -// launcher.add_field(4, FID_DATA); -// // regions[5](I/O): B_grad -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[1]->region_grad)); -// launcher.add_field(5, FID_DATA); -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -__host__ void - BatchMatmul::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static 
optional backward_task_impl(TaskArgumentAccessor const &acc) { // Currently assume C is NULL assert(regions.size() == 6); assert(task->regions.size() == 6); + // BatchMatmul* bmm = (BatchMatmul*) task->args; - TaskArgumentAccessor acc(task, regions, ctx, runtime); FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - // output domains - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor(OUTPUT_GRAD); - assert(output == - output_grad); // is this equivalent to checking `Domain` equality? - // A domains - auto a_input = acc.get_tensor(A_INPUT); - auto a_input_grad = acc.get_tensor(A_INPUT_GRAD); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + // is this equivalent to checking `Domain` equality? 
+ assert(output == output_grad); + + auto a_input = acc.get_tensor(A_INPUT); + auto a_input_grad = acc.get_tensor_grad(A_INPUT); assert(a_input == a_input_grad); - // B domains - auto b_input = acc.get_tensor(B_INPUT); - auto b_input_grad = acc.get_tensor(B_INPUT_GRAD); + + auto b_input = acc.get_tensor(B_INPUT); + auto b_input_grad = acc.get_tensor_grad(B_INPUT); assert(b_input == b_input_grad); // check dins - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); batch *= dim_size; } - // get pointers - float const *out_ptr = output.get_float_ptr(); - float const *out_grad_ptr = output_grad.get_float_ptr(); - float const *a_ptr = a_input.get_float_ptr(); - float *a_grad_ptr = a_input_grad.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float *b_grad_ptr = b_input_grad.get_float_ptr(); - - float *c_grad_ptr = NULL; // TODO: add support for meta->a_seq_length_dim >= 0 // or meta->b_seq_length_dim >= 0 assert((meta->a_seq_length_dim >= a_len) || (iter_config->seq_length == 0)); assert((meta->b_seq_length_dim >= b_len) || (iter_config->seq_length == 0)); - profile(backward_kernel, - meta->profiling, + return profile(backward_kernel, 
+ profiling, "[BatchMatmul] backward_time = %.2lfms\n", - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + a_input.get_float_ptr(), + a_input_grad.get_float_ptr(), + b_input.get_float_ptr(), + b_input_grad.get_float_ptr(), + NULL, //c_grad_ptr m, n, k, batch); } -void BatchMatmul::print_layer(FFModel const &ff) { - return; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -bool BatchMatmul::measure_operator_cost(Simulator *sim, - MachineView const &pc, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input0, sub_input1; - if (!outputs[0]->get_sub_tensor(pc, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(pc, sub_input0)) { - return false; - } - if (!inputs[1]->get_sub_tensor(pc, sub_input1)) { - return false; - } - - int input0_c = sub_input0.dims[0].size; - int input0_r = sub_input0.dims[1].size; - int input1_c = sub_input1.dims[0].size; - int input1_r = sub_input1.dims[1].size; - int output_c = sub_output.dims[0].size; - int output_r = sub_output.dims[1].size; - - assert(input0_c == input1_r); - assert(input0_r == output_r); - assert(input1_c == output_c); - - assert(sub_input0.dims[2] == sub_input1.dims[2]); - assert(sub_input1.dims[2] == sub_output.dims[2]); - int batch = 1; - assert(sub_input0.num_dims == sub_input1.num_dims); - for (int i = 2; i < sub_input0.num_dims; i++) { - assert(sub_input0.dims[i] == sub_input1.dims[i]); - assert(sub_input0.dims[i] == sub_output.dims[i]); - batch *= sub_input0.dims[i].size; - } - - BatchMatmulPerDeviceState *meta = sim->batch_matmul_meta; +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, 
+ ProfilingSettings const &settings, + MachineView const &pc) const { + auto env = sim.new_environment(); + + //todo add get_output_shape and get_weights_shape to batch_matmul op-attrs + // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); + // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); + + SimTaskBinding init_binding; + init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init_binding.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init_binding.bind_arg(HANDLE, ff_handle()); + + auto init_accessor = + env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, init_binding); + DeviceSpecificArg per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(A_INPUT, a_input); + fwd_binding.bind(B_INPUT, b_input); + // fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); +} - // allocate tensors in simulator - sim->free_all(); - float *a_ptr = (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - assert(a_ptr != NULL); - float *b_ptr = (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - assert(b_ptr != NULL); - float *c_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); - float *out_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_ptr != NULL); 
- cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + init.add_arg_slot(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init.add_arg_slot(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init.add_unchecked_arg_slot(HANDLE, ff_handle()); - int m = input1_c; - int n = input0_r; - int k = input0_c; + register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); +} - assert(meta->profiling == false); +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, meta, out_ptr, a_ptr, b_ptr, c_ptr, m, n, k, batch); - }; + fwd.add_input_slot(A_INPUT); + fwd.add_input_slot(B_INPUT); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - if (sim->computationMode == COMP_MODE_TRAINING) { - float *a_grad_ptr = - (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - float *b_grad_ptr = - (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - float *c_grad_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *out_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, - m, - n, - k, - batch); - }; - } + register_task(BATCHMATMUL_FWD_TASK_ID, "BatchMatmul Fwd", fwd, forward_task); +} - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf) backward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - 
output_r, - output_c, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time); - } +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(ATTENTION_FWD_TASK_ID)); - return true; + register_task(BATCHMATMUL_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); } -} -; // namespace FlexFlow + +}; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index c133c2a875..ab9bb45e8d 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -1,10 +1,14 @@ #ifndef _FLEXFLOW_BATCH_MATMUL_H #define _FLEXFLOW_BATCH_MATMUL_H +// #include "op-attrs/ops/batch_matmul.h" +// #include "task_spec/op_task_invocation.h" +// #include "task_spec/op_task_signature.h" +// #include "sim_environment.h" + #include "op-attrs/ops/batch_matmul.h" -#include "op_task_invocation.h" -#include "op_task_signature.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -19,12 +23,12 @@ OpTaskInvocation init(BatchMatmulAttrs const &); OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - BatchMatmulAttrs const &attrs, - ParallelTensorShape const &lhs_input_shape, - ParallelTensorShape const &rhs_input_shape, - ProfilingSettings const &settings, - MachineView const &); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc); /* class BatchMatmul : public Op { */ /* public: 
*/ @@ -84,3 +88,424 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + + + +// BatchMatmulParams BatchMatmul::get_params() const { +// BatchMatmulParams params; +// params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; +// params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; +// return params; +// } + +// Tensor FFModel::batch_matmul(const Tensor A, +// const Tensor B, +// int a_seq_length_dim, +// int b_seq_length_dim, +// char const *name) { +// Layer *bmm = new Layer(this, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B); +// assert((a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// int dims[MAX_TENSOR_DIM]; +// int numdim = A->num_dims; +// for (int i = 0; i < numdim; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// bmm->outputs[0] = create_tensor_legion_ordering( +// numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); +// bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); +// bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); +// layers.push_back(bmm); +// return bmm->outputs[0]; +// } + +// Op *BatchMatmul::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("a_seq_length_dim", value); +// int a_seq_length_dim = value; +// layer->get_int_property("b_seq_length_dim", value); +// int b_seq_length_dim = value; +// return new BatchMatmul(model, +// inputs[0], +// inputs[1], +// 
a_seq_length_dim, +// b_seq_length_dim, +// layer->name); +// } + +// BatchMatmul::BatchMatmul( +// FFModel &model, +// BatchMatmulParams const ¶ms, +// std::pair const &inputs, +// char const *name) +// : BatchMatmul(model, +// inputs.first, +// inputs.second, +// params.a_seq_length_dim, +// params.b_seq_length_dim, +// name) {} + +// // return A*B +// BatchMatmul::BatchMatmul(FFModel &model, +// const ParallelTensor A, +// const ParallelTensor B, +// int _a_seq_length_dim, +// int _b_seq_length_dim, +// char const *name) +// : Op(model, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B), +// a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), +// b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { +// assert((_a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((_b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < A->num_dims; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// numOutputs = 1; +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// A->num_dims, dims, DT_FLOAT, this); +// // C is not none +// // if (C != Tensor::NO_TENSOR) { +// // numInputs = 3; +// // assert(C.num_dims == outputs[0].num_dims); +// // for (int i = 0; i < C.num_dims; i++) +// // assert(C.adim[i] == outputs[0].adim[i]); +// //} +// } + +// void BatchMatmul::serialize(Legion::Serializer &sez) const { +// BatchMatmulParams params = get_params(); +// sez.serialize(params.a_seq_length_dim); +// sez.serialize(params.b_seq_length_dim); +// } + +// using PCG::Node; +// /*static*/ +// Node 
BatchMatmul::deserialize(FFModel &ff, +// Legion::Deserializer &dez, +// ParallelTensor inputs[], +// int num_inputs) { +// assert(num_inputs == 2); +// int a_seq_length_dim, b_seq_length_dim; +// dez.deserialize(a_seq_length_dim); +// dez.deserialize(b_seq_length_dim); + +// BatchMatmulParams params; +// params.a_seq_length_dim = a_seq_length_dim; +// params.b_seq_length_dim = b_seq_length_dim; +// return ff.get_or_create_node({inputs[0], inputs[1]}, params); +// } + +// Op *BatchMatmul::materialize(FFModel &ff, +// ParallelTensor inputs[], +// int num_inputs) const { +// BatchMatmulParams params = get_params(); +// return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); +// } + + +// void BatchMatmul::forward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // forward_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); +// break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + +// template +// void BatchMatmul::forward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_FWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// for (int i = 0; i < numInputs; i++) { +// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[i]->region)); +// launcher.add_field(i + 1, FID_DATA); +// } +// 
runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](O): output + regions[1](I): A + regions[2](I): B + ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? + output = A * B /////////+ C +*/ + + +// void BatchMatmul::init(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // init_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); +// break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } // namespace FlexFlow +// // / +// // template +// // void BatchMatmul::init_with_dim(FFModel const &ff) { +// // assert(check_output_input_weight_same_parallel_is()); +// // parallel_is = outputs[0]->parallel_is; +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_init(ff, argmap); +// // IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(BatchMatmul)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // FutureMap fm = runtime->execute_index_space(ctx, launcher); +// // fm.wait_all_results(); +// // set_opmeta_from_futuremap(ff, fm); +// // } + +// OpTaskBinding BatchMatmul::get_bwd_task_binding() const { +// OpTaskBinding binding; +// binding.bind(A_INPUT, input_tensor(0)); +// 
binding.bind(B_INPUT, input_tensor(1)); +// binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); +// binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); + +// binding.bind(OUTPUT, output_tensor(0)); +// binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + + +// static OpTaskSignature get_fwd_task_signature() { +// OpTaskSignature fwd(OpTaskType::FWD); + +// fwd.add_input_slot(A_INPUT, READ_WRITE); +// fwd.add_input_slot(B_INPUT, READ_WRITE); +// fwd.add_output_slot(OUTPUT); + +// return fwd; +// } + +// static OpTaskSignature get_bwd_task_signature() { +// OpTaskSignature bwd(OpTaskType::BWD); + +// bwd.add_input_slot(A_INPUT); +// bwd.add_input_slot(B_INPUT); +// bwd.add_input_grad_slot(A_INPUT_GRAD); +// bwd.add_input_grad_slot(B_INPUT_GRAD); +// bwd.add_output_slot(OUTPUT); +// bwd.add_output_grad_slot(OUTPUT_GRAD); + +// return bwd; +// } + +// OpTaskBinding BatchMatmul::get_init_task_binding() const { +// OpTaskBinding binding; + +// binding.bind_arg(ATTRS, this->attrs); +// binding.bind_arg(PROFILING, this->profiling); + +// return binding; +// } + +// OpTaskBinding BatchMatmul::get_fwd_task_binding() const { +// OpTaskBinding binding; + +// binding.bind(A_INPUT, input_tensor(0)); +// binding.bind(B_INPUT, input_tensor(1)); +// binding.bind(OUTPUT, output_tensor(0)); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + +//void BatchMatmul::backward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// backward_with_dim(ff); \ +// break; \ +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + + +// void BatchMatmul::print_layer(FFModel const &ff) { +// return; +// } + + + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + regions[5](I/O): B_grad + regions[6](I/O): C_grad 
+*/ +// template +// void BatchMatmul::backward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_BWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// // regions[0](I): output +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// // regions[1](I): output_grad +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// // regions[2](I): A +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(2, FID_DATA); +// // regions[3](I/O): A_grad +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // regions[4](I): B +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[1]->region)); +// launcher.add_field(4, FID_DATA); +// // regions[5](I/O): B_grad +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[1]->region_grad)); +// launcher.add_field(5, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + 
regions[5](I/O): B_grad + regions[6](I/O): C_grad +*/ \ No newline at end of file From d9aae744914fb6d710a5f174d4e88cc0cd7a95d5 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 22 Aug 2023 14:29:35 -0700 Subject: [PATCH 02/24] fix FF_VISITABLE_STRUCT_NO_EQ --- lib/kernels/include/kernels/batch_matmul_kernels.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index cdffdf1907..b8e6b9cd4b 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -14,7 +14,10 @@ struct BMMPerDeviceState { }; FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, - handle,); + handle, + allocator, + a_seq_length_dim, + b_seq_length_dim); namespace Kernels { namespace BatchMatmul { From 2073822dbc697615407763f10a8383e864c6929e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 12:03:39 -0700 Subject: [PATCH 03/24] finish draft 1 batch_matmul --- .../include/kernels/batch_matmul_kernels.h | 7 ++-- lib/runtime/include/runtime/config.h | 8 +++-- lib/runtime/src/ops/batch_matmul.cc | 32 ++++++++++--------- lib/runtime/src/task_spec/op_arg_ref.h | 8 ++++- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index b8e6b9cd4b..1506305e9e 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -10,7 +10,7 @@ struct BMMPerDeviceState { PerDeviceFFHandle handle; Allocator allocator; int a_seq_length_dim; - int b_seq_length_dim; + req b_seq_length_dim; }; FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, @@ -28,7 +28,7 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, int b_seq_length_dim); void forward_kernel(ffStream_t stream, - BMMPerDeviceState const *meta, + BMMPerDeviceState const &meta, float *o_ptr, float const 
*a_ptr, float const *b_ptr, @@ -40,14 +40,13 @@ void forward_kernel(ffStream_t stream, int seq_length = -1); void backward_kernel(ffStream_t stream, - BMMPerDeviceState const *meta, + BMMPerDeviceState const &meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, float *a_grad_ptr, float const *b_ptr, float *b_grad_ptr, - float *c_grad_ptr, int m, int n, int k, diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index a7b8d86171..fc45c2c76c 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -104,13 +104,15 @@ struct FFConfig : public use_visitable_cmp { int python_data_loader_type = 2; }; -class FFIterationConfig { -public: +struct FFIterationConfig { FFIterationConfig(); void reset(); - int seq_length; + req seq_length; }; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, + seq_length); + enum FieldIDs { FID_DATA, }; diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 9bbc050f8b..761c565657 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -41,7 +41,8 @@ enum Slots { HANDLE, A_SEQ_LENGTH_DIM, B_SEQ_LENGTH_DIM, - PER_DEVICE_STATE + PER_DEVICE_STATE, + ITERATION_CONFIG }; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { @@ -63,6 +64,7 @@ OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { fwd.bind_arg(PROFILING, profiling_settings()); fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + fwd.bind_arg(ITERATION_CONFIG, iteration_config()); return {BATCHMATMUL_FWD_TASK_ID, fwd}; } @@ -110,7 +112,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; + FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); int m = 
b_input.shape[legion_dim_t(0)]; assert(m == output.shape[legion_dim_t(0)]); @@ -123,7 +125,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { + for (int i = 2; i < a_input.shape.get_dim(); i++) { //get_dim() or get_volume()? int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -137,12 +139,12 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { output.get_float_ptr(), a_input.get_float_ptr(), b_input.get_float_ptr(), - NULL, //c_ptr + nullptr, //c_ptr m, n, k, batch, - iter_config->seq_length); + iter_config.seq_length); } static void forward_task(Task const *task, @@ -159,7 +161,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(task->regions.size() == 6); // BatchMatmul* bmm = (BatchMatmul*) task->args; - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; + FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); @@ -186,7 +188,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { + for (int i = 2; i < a_input.shape.get_dim(); i++) { //@colin get_dim() or get_volume()? 
int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -195,8 +197,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { // TODO: add support for meta->a_seq_length_dim >= 0 // or meta->b_seq_length_dim >= 0 - assert((meta->a_seq_length_dim >= a_len) || (iter_config->seq_length == 0)); - assert((meta->b_seq_length_dim >= b_len) || (iter_config->seq_length == 0)); + assert((per_device_state.a_seq_length_dim >= a_len) || (iter_config.seq_length == 0)); + assert((per_device_state.b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); return profile(backward_kernel, profiling, "[BatchMatmul] backward_time = %.2lfms\n", per_device_state, output.get_float_ptr(), output_grad.get_float_ptr(), a_input.get_float_ptr(), a_input_grad.get_float_ptr(), b_input.get_float_ptr(), b_input_grad.get_float_ptr(), - NULL, //c_grad_ptr m, n, k, @@ -228,10 +229,10 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, InputParallelTensorDesc const &a_input, InputParallelTensorDesc const &b_input, ProfilingSettings const &settings, - MachineView const &pc) const { + MachineView const &pc) { auto env = sim.new_environment(); - //todo add get_output_shape and get_weights_shape to batch_matmul op-attrs + // @colin todo add get_output_shape and get_weights_shape to batch_matmul op-attrs // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); @@ -249,6 +249,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding fwd_binding; fwd_binding.bind(A_INPUT, a_input); fwd_binding.bind(B_INPUT, b_input); + //@colin will uncomment after get_output_shape is done // fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); @@ -268,9 +270,9 @@ template <> void register_task() { OpTaskSignature init(OpTaskType::INIT); - init.add_arg_slot(A_SEQ_LENGTH_DIM, 
get_aSeqLengthDim(attrs)); - init.add_arg_slot(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); - init.add_unchecked_arg_slot(HANDLE, ff_handle()); + init.add_arg_slot(A_SEQ_LENGTH_DIM); + init.add_arg_slot(B_SEQ_LENGTH_DIM); + init.add_unchecked_arg_slot(HANDLE); register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); } diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index 6e921c05e8..bc4f38e5fb 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -4,10 +4,12 @@ #include "arg_ref.h" #include "device_specific_arg.h" #include "op-attrs/parallel_tensor_shape.h" +#include "runtime/config.h" + namespace FlexFlow { -enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE }; +enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE, ITERATION_CONFIG }; template using OpArgRef = ArgRef; @@ -23,6 +25,10 @@ OpArgRef input_parallel_tensor_shape(int idx) { return {OpArgRefType::PARALLEL_TENSOR_SHAPE}; } +OpArgRef iteration_config() { + return {OpArgRefType::ITERATION_CONFIG}; +} + } // namespace FlexFlow #endif From d5806a01163ab9fb13f0457c7cc4ba4e330b41d0 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 13:35:55 -0700 Subject: [PATCH 04/24] add output and weights --- lib/op-attrs/include/op-attrs/ops/batch_matmul.h | 3 +++ lib/runtime/src/ops/batch_matmul.cc | 15 ++++----------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 00c700ba20..fa80ed3da9 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -14,6 +14,9 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); +ParallelTensorShape 
get_weights_shape(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 761c565657..3fabc86d1a 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -13,16 +13,11 @@ * limitations under the License. */ -// #include "batch_matmul.h" -// #include "kernels/batch_matmul_kernels.h" -// #include "kernels/profiling.h" -// #include "legion/legion_utilities.h" -// #include "tasks.h" - #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" #include "legion.h" #include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/get_output_shapes.h" namespace FlexFlow { @@ -232,9 +227,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, MachineView const &pc) { auto env = sim.new_environment(); - // @colin todo add get_output_shape and get_weights_shape to batch_matmul op-attrs - // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); - // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); + ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape weight_shape = get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); @@ -249,8 +243,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding fwd_binding; fwd_binding.bind(A_INPUT, a_input); fwd_binding.bind(B_INPUT, b_input); - //@colin will uncomment after get_output_shape is done - // fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); From f620070ea8bc8e7373915a5063e1a5977ed7f433 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 13:45:00 
-0700 Subject: [PATCH 05/24] format --- .../include/kernels/batch_matmul_kernels.h | 15 +-- .../include/op-attrs/ops/batch_matmul.h | 4 +- lib/runtime/include/runtime/config.h | 3 +- lib/runtime/src/ops/batch_matmul.cc | 106 +++++++++--------- lib/runtime/src/ops/batch_matmul.h | 62 +++++----- lib/runtime/src/task_spec/op_arg_ref.h | 7 +- 6 files changed, 97 insertions(+), 100 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 1506305e9e..23522a89f7 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -13,19 +13,16 @@ struct BMMPerDeviceState { req b_seq_length_dim; }; -FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, - handle, - allocator, - a_seq_length_dim, - b_seq_length_dim); +FF_VISITABLE_STRUCT_NO_EQ( + BMMPerDeviceState, handle, allocator, a_seq_length_dim, b_seq_length_dim); namespace Kernels { namespace BatchMatmul { BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, - Allocator allocator, - int a_seq_length_dim, - int b_seq_length_dim); + Allocator allocator, + int a_seq_length_dim, + int b_seq_length_dim); void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, @@ -40,7 +37,7 @@ void forward_kernel(ffStream_t stream, int seq_length = -1); void backward_kernel(ffStream_t stream, - BMMPerDeviceState const &meta, + BMMPerDeviceState const &meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index fa80ed3da9..6ac2a2cf69 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -15,8 +15,8 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); ParallelTensorShape 
get_weights_shape(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); + ParallelTensorShape const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index fc45c2c76c..21cfca20cf 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -110,8 +110,7 @@ struct FFIterationConfig { req seq_length; }; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, - seq_length); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); enum FieldIDs { FID_DATA, diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 3fabc86d1a..0f4ad9e865 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -16,8 +16,8 @@ #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" #include "legion.h" -#include "op-attrs/ops/batch_matmul.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/batch_matmul.h" namespace FlexFlow { @@ -29,16 +29,16 @@ using Legion::Runtime; using Legion::Task; enum Slots { - A_INPUT, //tensor - B_INPUT, //tensor - OUTPUT, //tensor + A_INPUT, // tensor + B_INPUT, // tensor + OUTPUT, // tensor PROFILING, HANDLE, A_SEQ_LENGTH_DIM, B_SEQ_LENGTH_DIM, PER_DEVICE_STATE, ITERATION_CONFIG - }; +}; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { OpTaskBinding init; @@ -70,18 +70,16 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static DeviceSpecificArg init_task_impl(TaskArgumentAccessor const &acc) { +static DeviceSpecificArg + init_task_impl(TaskArgumentAccessor const &acc) { auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = 
acc.get_allocator(); - DeviceSpecificArg per_device_state = + DeviceSpecificArg per_device_state = acc.create_device_specific( - init_kernel(handle, - allocator, - a_seq_length_dim, - b_seq_length_dim)); + init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); // assert(weight.shape.get_volume() * sizeof(float) == // acc.unwrap(per_device_state)->weightSize); @@ -107,7 +105,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); - FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); int m = b_input.shape[legion_dim_t(0)]; assert(m == output.shape[legion_dim_t(0)]); @@ -120,7 +119,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); i++) { //get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.get_dim(); + i++) { // get_dim() or get_volume()? 
int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -128,18 +128,18 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { } return profile(forward_kernel, - profiling, - "[BatchMatmul] forward_time = %.2lfms\n", - per_device_state, - output.get_float_ptr(), - a_input.get_float_ptr(), - b_input.get_float_ptr(), - nullptr, //c_ptr - m, - n, - k, - batch, - iter_config.seq_length); + profiling, + "[BatchMatmul] forward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + a_input.get_float_ptr(), + b_input.get_float_ptr(), + nullptr, // c_ptr + m, + n, + k, + batch, + iter_config.seq_length); } static void forward_task(Task const *task, @@ -156,14 +156,15 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(task->regions.size() == 6); // BatchMatmul* bmm = (BatchMatmul*) task->args; - FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); // is this equivalent to checking `Domain` equality? - assert(output == output_grad); + assert(output == output_grad); auto a_input = acc.get_tensor(A_INPUT); auto a_input_grad = acc.get_tensor_grad(A_INPUT); @@ -183,7 +184,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); i++) { //@colin get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.get_dim(); + i++) { //@colin get_dim() or get_volume()? 
int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -196,19 +198,19 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); return profile(backward_kernel, - profiling, - "[BatchMatmul] backward_time = %.2lfms\n", - per_device_state, - output.get_float_ptr(), - output_grad.get_float_ptr(), - a_input.get_float_ptr(), - a_input_grad.get_float_ptr(), - b_input.get_float_ptr(), - b_input_grad.get_float_ptr(), - m, - n, - k, - batch); + profiling, + "[BatchMatmul] backward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + a_input.get_float_ptr(), + a_input_grad.get_float_ptr(), + b_input.get_float_ptr(), + b_input_grad.get_float_ptr(), + m, + n, + k, + batch); } static void backward_task(Task const *task, @@ -220,15 +222,17 @@ static void backward_task(Task const *task, } CostMetrics measure_operator_cost(SimEnvFactory const &sim, - BatchMatmulAttrs const &attrs, - InputParallelTensorDesc const &a_input, - InputParallelTensorDesc const &b_input, - ProfilingSettings const &settings, - MachineView const &pc) { + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc) { auto env = sim.new_environment(); - ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); - ParallelTensorShape weight_shape = get_weights_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape output_shape = + get_output_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape weight_shape = + get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); @@ -249,8 +253,10 @@ CostMetrics 
measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - auto fwd_accessor = env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); - auto bwd_accessor = env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); + auto fwd_accessor = + env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); float forward_time = forward_task_impl(fwd_accessor).value(); float backward_time = backward_task_impl(bwd_accessor).value(); diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index ab9bb45e8d..42c9bc23de 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -24,11 +24,11 @@ OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim, - BatchMatmulAttrs const &attrs, - InputParallelTensorDesc const &a_input, - InputParallelTensorDesc const &b_input, - ProfilingSettings const &settings, - MachineView const &pc); + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc); /* class BatchMatmul : public Op { */ /* public: */ @@ -89,8 +89,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, #endif - - // BatchMatmulParams BatchMatmul::get_params() const { // BatchMatmulParams params; // params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; @@ -242,15 +240,14 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); // } - // void BatchMatmul::forward(FFModel const &ff) { // int dim = outputs[0]->num_dims; // switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ +// #define DIMFUNC(DIM) \ +// case DIM: { \ // // 
forward_with_dim(ff); -// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); -// break; +// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, +// get_fwd_task_signature()); break; // } // LEGION_FOREACH_N(DIMFUNC) // #undef DIMFUNC @@ -299,15 +296,14 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, output = A * B /////////+ C */ - // void BatchMatmul::init(FFModel const &ff) { // int dim = outputs[0]->num_dims; // switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ +// #define DIMFUNC(DIM) \ +// case DIM: { \ // // init_with_dim(ff); -// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); -// break; +// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, +// get_init_task_signature()); break; // } // LEGION_FOREACH_N(DIMFUNC) // #undef DIMFUNC @@ -365,7 +361,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return binding; // } - // static OpTaskSignature get_fwd_task_signature() { // OpTaskSignature fwd(OpTaskType::FWD); @@ -409,28 +404,25 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return binding; // } -//void BatchMatmul::backward(FFModel const &ff) { -// int dim = outputs[0]->num_dims; -// switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ -// backward_with_dim(ff); \ -// break; \ +// void BatchMatmul::backward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #d ef ine DIMFUNC(DIM) \ +// case DIM: { \ +// backward_with_dim(ff); \ +// break; \ // } -// LEGION_FOREACH_N(DIMFUNC) -// #undef DIMFUNC -// default: -// assert(false); -// } -// } - +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } // void BatchMatmul::print_layer(FFModel const &ff) { // return; // } - - /* regions[0](I): output regions[1](I): output_grad @@ -508,4 +500,4 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, regions[4](I): B regions[5](I/O): B_grad regions[6](I/O): C_grad -*/ \ No newline at 
end of file +*/ diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index bc4f38e5fb..07316c877b 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -6,10 +6,13 @@ #include "op-attrs/parallel_tensor_shape.h" #include "runtime/config.h" - namespace FlexFlow { -enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE, ITERATION_CONFIG }; +enum class OpArgRefType { + PER_DEVICE_OP_STATE, + PARALLEL_TENSOR_SHAPE, + ITERATION_CONFIG +}; template using OpArgRef = ArgRef; From 5458923bc344d2cefbd6bc18b7c3719e73291ba8 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 15:23:36 -0700 Subject: [PATCH 06/24] fix DeviceSpecific --- lib/runtime/src/ops/batch_matmul.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 0f4ad9e865..59570f4417 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -70,14 +70,14 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static DeviceSpecificArg +static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = acc.get_allocator(); - DeviceSpecificArg per_device_state = + DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); @@ -86,7 +86,7 @@ static DeviceSpecificArg return per_device_state; } -static DeviceSpecificArg +static DeviceSpecific init_task(Task const *task, std::vector const ®ions, Context ctx, @@ -241,7 +241,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, auto init_accessor = env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, 
init_binding); - DeviceSpecificArg per_device_state = + DeviceSpecific per_device_state = init_task_impl(init_accessor); SimTaskBinding fwd_binding; From 2f604dca90edf8ecaef4dade5bbb6d0fc4ed612f Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 25 Aug 2023 18:57:40 -0700 Subject: [PATCH 07/24] batch_norm --- .../include/kernels/batch_norm_kernels.h | 70 +- lib/runtime/src/ops/batch_norm.cc | 700 ++++++------------ lib/runtime/src/ops/batch_norm.h | 242 +++++- 3 files changed, 537 insertions(+), 475 deletions(-) diff --git a/lib/kernels/include/kernels/batch_norm_kernels.h b/lib/kernels/include/kernels/batch_norm_kernels.h index 6ff90299db..74dfc96068 100644 --- a/lib/kernels/include/kernels/batch_norm_kernels.h +++ b/lib/kernels/include/kernels/batch_norm_kernels.h @@ -8,30 +8,66 @@ namespace FlexFlow { -class BatchNormPerDeviceState : public PerDeviceOpState { -public: - BatchNormPerDeviceState(FFHandler handle, - std::unique_ptr allocator, - int output_n, - int output_c, - int output_h, - int output_w, - bool relu, - bool profiling); - ~BatchNormPerDeviceState(void); - - ffTensorDescriptor_t inputTensor, outputTensor, biasTensor; +struct BatchNormPerDeviceState { + PerDeviceFFHandle handle; + Allocator allocator; + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffTensorDescriptor_t biasTensor; ffActivationDescriptor_t actiDesc; ffBatchNormMode_t mode; - float *runningMean, *runningVar, *saveMean, *saveVar; - bool relu; - bool profiling; - std::unique_ptr allocator; + float *runningMean; + float *runningVar; + float *saveMean; + float *saveVar; + int output_n; + int output_c; + int output_h; + int output_w; + ProfilingSettings profiling; + req relu; }; +FF_VISITABLE_STRUCT_NO_EQ(BatchNormPerDeviceState, + handle, + allocator, + inputTensor, + outputTensor, + biasTensor, + actiDesc, + mode, + runningMean, + runningVar, + saveMean, + saveVar, + output_n, + output_c, + output_h, + output_w, + profiling, + relu); + namespace Kernels { 
namespace BatchNorm { +BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle, + Allocator allocator, + ffTensorDescriptor_t inputTensor, + ffTensorDescriptor_t outputTensor, + ffTensorDescriptor_t biasTensor, + ffActivationDescriptor_t actiDesc, + ffBatchNormMode_t mode, + float *runningMean, + float *runningVar, + float *saveMean, + float *saveVar, + int output_n, + int output_c, + int output_h, + int output_w, + ProfilingSettings profiling, + bool relu); + void forward_kernel(ffStream_t stream, BatchNormPerDeviceState *m, float const *input_ptr, diff --git a/lib/runtime/src/ops/batch_norm.cc b/lib/runtime/src/ops/batch_norm.cc index 98cc4576a1..ffd52c96fb 100644 --- a/lib/runtime/src/ops/batch_norm.cc +++ b/lib/runtime/src/ops/batch_norm.cc @@ -16,505 +16,291 @@ #include "batch_norm.h" #include "kernels/batch_norm_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" + +namespace FlexFlow { using namespace FlexFlow::Kernels::BatchNorm; -namespace FlexFlow { +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; enum Slots { - INPUT, - SCALE, - BIAS, - OUTPUT, - INPUT_GRAD, - SCALE_GRAD, - BIAS_GRAD, - OUTPUT_GRAD, + INPUT, // tensor + SCALE, // tensor + BIAS, // tensor + OUTPUT, // tensor ATTRS, - PROFILING -} - -Tensor - FFModel::batch_norm(const Tensor input, bool relu, char const *name) { - assert(input->num_dims == 4); /*NCHW*/ - Layer *bm = new Layer(this, - OP_BATCHNORM, - DT_FLOAT, - name, - 1 /*inputs*/, - 2 /*weights*/, - 1 /*outputs*/, - input); - int numdims = 4; - bm->outputs[0] = create_tensor_legion_ordering( - numdims, input->dims, DT_FLOAT, bm, 0, true /*create_grad*/); - bm->add_int_property("relu", relu); - layers.push_back(bm); - return bm->outputs[0]; -} - -/* - locals[0] = scale - locals[1] = bias -*/ -BatchNorm::BatchNorm(FFModel &model, - const ParallelTensor _input, - const ParallelTensor _scale, - const ParallelTensor _bias, - bool _relu, - char const *name) - : 
Op(model, - OP_BATCHNORM, - DT_FLOAT, - name, - 1 /*inputs*/, - 2 /*weights*/, - 1 /*outputs*/, - _input, - _scale, - _bias), - relu(_relu) { - assert(_input->num_dims == 4); - numOutputs = 1; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < _input->num_dims; i++) { - dims[i] = _input->dims[_input->num_dims - 1 - i]; - } - outputs[0] = - model.create_parallel_tensor(_input->num_dims, dims, DT_FLOAT, this); - return; -} - -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - - // init.add_input_slot(INPUT); - // init.add_param_slot(SCALE); - // init.add_param_slot(BIAS); - init.add_output_slot(OUTPUT); -} + PROFILING, + PER_DEVICE_STATE, + RELU, + HANDLE +}; -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_param_slot(SCALE); - fwd.add_param_slot(BIAS); - fwd.add_output_slot(OUTPUT, WRITE_DISCARD); - - return fwd; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_arg_slot(ATTRS); - - bwd.add_input_slot(INPUT); - bwd.add_input_grad_slot(INPUT_GRAD, READ_WRITE); - bwd.add_param_slot(SCALE); - bwd.add_param_grad_slot(SCALE_GRAD, READ_WRITE); - bwd.add_param_grad_slot(BIAS_GRAD, READ_WRITE); - bwd.add_output_grad_slot(OUTPUT_GRAD); - - return bwd; -} - -OpTaskBinding BatchNorm::get_init_task_binding() const { +OpTaskInvocation init(BatchNormAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); - - // binding.bind(INPUT, input_tensor(0)); - // binding.bind(SCALE, param_tensor(0)); - // binding.bind(BIAS, param_tensor(1)); + binding.bind(INPUT, input_tensor(0)); + binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + binding.bind_arg(ATTRS, attrs); + 
binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(HANDLE, ff_handle()); + + return {BATCHNORM_INIT_TASK_ID, binding}; } -OpTaskBinding BatchNorm::get_fwd_task_binding() const { +OpTaskInvocation forward(BatchNormAttrs const &attrs) { OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); binding.bind(INPUT, input_tensor(0)); - binding.bind(SCALE, param_tensor(0)); - binding.bind(BIAS, param_tensor(1)); + binding.bind(SCALE, input_tensor(1)); + binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {BATCHNORM_FWD_TASK_ID, binding}; } -OpTaskBinding BatchNorm::get_bwd_task_binding() const { - OpTaskBinding binding; +OpTaskInvocation backward(BatchNormAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - binding.bind_arg(ATTRS, this->attrs); + return {BATCHNORM_BWD_TASK_ID, binding}; +} - binding.bind(INPUT, input_tensor(0)); - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(SCALE, param_tensor(0)); - binding.bind(SCALE_GRAD, param_tensor(0).grad()); - binding.bind(BIAS_GRAD, param_tensor(1).grad()); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + Allocator allocator = acc.get_allocator(); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto output = acc.get_tensor(OUTPUT); + auto const &attrs = acc.get_argument(ATTRS); + + int output_w = output.shape[legion_dim_t(0)]; + int output_h = output.shape[legion_dim_t(1)]; + int output_c = output.shape[legion_dim_t(2)]; + int output_n = output.shape[legion_dim_t(3)]; + + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffTensorDescriptor_t biasTensor; + ffActivationDescriptor_t actiDesc; + 
ffBatchNormMode_t mode; + + size_t totalSize = sizeof(float) * output_c * 4; + float *runningMean = (float *)allocator.allocate(totalSize); + float *runningVar = (float *)runningMean + output_c; + float *saveMean = (float *)runningVar + output_c; + float *saveVar = (float *)saveMean + output_c; + + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, + allocator, + inputTensor, + outputTensor, + biasTensor, + actiDesc, + mode, + runningMean, + runningVar, + saveMean, + saveVar, + output_n, + output_c, + output_h, + output_w, + profiling, + attrs.relu)); + + return per_device_state; +} - return binding; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -void BatchNorm::init(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(BATCHNORM_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(BatchNorm)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // 
weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[1]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[1]->region)); - // launcher.add_field(3, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + assert(regions.size() == 4); + assert(task->regions.size() == 4); + + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto scale = acc.get_tensor(SCALE); + auto bias = acc.get_tensor(SCALE); + + return profile(forward_kernel, + profiling, + "[BatchNorm] forward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output.get_float_ptr(), + scale.get_float_ptr(), + bias.get_float_ptr()); } -/* - regions[0]: input - regions[1]: output - regions[2](I): scale - regions[3](I): bias -*/ -PerDeviceOpState * - BatchNorm::init_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); TaskArgumentAccessor acc(task, regions, ctx, runtime); - FFHandler handle = *((FFHandler const *)task->local_args); - - auto output = acc.get_tensor(OUTPUT); - - int output_w = output.shape[0]; - int output_h = output.shape[1]; - int output_c = output.shape[2]; - int output_n = output.shape[3]; - - Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) - .only_kind(Memory::GPU_FB_MEM) - .best_affinity_to(task->target_proc) - .first(); - BatchNormPerDeviceState *m = new BatchNormPerDeviceState( - handle, bm, gpu_mem, output_n, output_c, output_h, output_w); - return m; + forward_task_impl(acc); } -void BatchNorm::forward(FFModel const 
&ff) { - this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(BATCHNORM_FWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, 0), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_DISCARD, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[1]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[1]->region)); - // launcher.add_field(3, FID_DATA); - - // runtime->execute_index_space(ctx, launcher); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + assert(regions.size() == 7); + assert(task->regions.size() == 7); + + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + auto scale = acc.get_tensor(SCALE); + auto scale_grad = acc.get_tensor_grad(SCALE); + auto bias_grad = acc.get_tensor_grad(BIAS); + + return profile(backward_kernel, + profiling, + "[BatchNorm] backward_time = %.2lfms\n", + &per_device_state, + 
input.get_float_ptr(), + output_grad.get_float_ptr(), + output.get_float_ptr(), + input_grad.get_float_ptr(), + scale.get_float_ptr(), + scale_grad.get_float_ptr(), + bias_grad.get_float_ptr(), + output.shape.get_volume()); } -/* - regions[0](I): input - regions[1](O): ouptut - regions[2](I): scale - regions[3](I): bias -*/ -void BatchNorm::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - // const BatchNorm* bm = (BatchNorm*) task->args; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - BatchNormPerDeviceState *m = *((BatchNormPerDeviceState **)task->local_args); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto scale = acc.get_tensor(SCALE); - auto bias = acc.get_tensor(SCALE); - - profile(forward_kernel, - m->profiling, - "[BatchNorm] forward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output.get_float_ptr(), - scale.get_float_ptr(), - bias.get_float_ptr()); + backward_task_impl(acc); } -void BatchNorm::backward(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(BATCHNORM_BWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, 0), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // // regions[0](I): input - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // // regions[1](I/O): input_grad (we only need grad tensors) - // 
launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // inputs[0]->region_grad)); - // launcher.add_field(1, FID_DATA); - // // regions[2](I): output - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(2, FID_DATA); - // // regions[3](I/O): output_grad - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(3, FID_DATA); - // // regions[4](I): filter - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(4, FID_DATA); - // // regions[5](I/O): filter_grad - // launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // weights[0]->region_grad)); - // launcher.add_field(5, FID_DATA); - // // regions[6](I/O): bias_grad - // launcher.add_region_requirement(RegionRequirement(weights[1]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // weights[1]->region_grad)); - // launcher.add_field(6, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchNormAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + InputParallelTensorDesc const &scale_shape, + InputParallelTensorDesc const &bias_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + + // int output_w = sub_output.dims[0].size; + // int output_h = sub_output.dims[1].size; + // int output_c = sub_output.dims[2].size; + // int output_n = sub_output.dims[3].size; + // BatchNormPerDeviceState *m = new BatchNormPerDeviceState( + // 
sim->handler, this, sim->memory, output_n, output_c, output_h, + // output_w); + + // sim->free_all(); + // float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), + // DT_FLOAT); assert(input_ptr != NULL); cost_metrics.inputs_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + // float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), + // DT_FLOAT); assert(output_ptr != NULL); cost_metrics.outputs_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + // float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); + // assert(bias_ptr != NULL); + // float *scale_ptr = (float *)sim->allocate(output_c, DT_FLOAT); + // assert(scale_ptr != NULL); + // cost_metrics.weights_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + auto env = sim.new_environment(); + + ParallelTensorShape output_shape = get_output_shape(attrs); + + SimTaskBinding init_binding; + init_binding.bind(INPUT, input_shape); + init_binding.bind(BIAS, bias_shape); + init_binding.bind(OUTPUT, output_shape); + + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(PROFILING, settings); + init_binding.bind_arg(HANDLE, ff_handle()); + + auto init_accessor = + env.get_init_accessor(ATTENTION_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(SCALE, scale_shape); + fwd_binding.bind(BIAS, bias_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = env.get_fwd_accessor(ATTENTION_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(ATTENTION_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + 
float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -/* - regions[0](I): input - regions[1](I/O): input_grad - regions[2](I): output - regions[3](I/O): output_grad - regions[4](I): scale - regions[5](I/O): scale_grad - regions[6](I/O): bias_grad -*/ -__host__ void - BatchNorm::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 7); - assert(task->regions.size() == 7); - // float beta = 0.0f; - // const BatchNorm* bm = (BatchNorm*) task->args; - TaskArgumentAccessor acc(task, regions, ctx, runtime); - BatchNormPerDeviceState *m = *((BatchNormPerDeviceState **)task->local_args); - - auto input = acc.get_tensor(INPUT); - auto input_grad = acc.get_tensor_grad(INPUT_GRAD); - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT_GRAD); - auto scale = acc.get_tensor(SCALE); - auto scale_grad = acc.get_tensor_grad(SCALE_GRAD); - auto bias_grad = acc.get_tensor_grad(BIAS_GRAD); - - profile(backward_kernel, - m->profiling, - "[BatchNorm] backward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output_grad.get_float_ptr(), - output.get_float_ptr(), - input_grad.get_float_ptr(), - scale.get_float_ptr(), - scale_grad.get_float_ptr(), - bias_grad.get_float_ptr(), - output.get_volume()); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + init.add_input_slot(INPUT); + init.add_input_slot(BIAS); + init.add_output_slot(OUTPUT); + init.add_arg_slot(ATTRS); + init.add_arg_slot(PROFILING); + init.add_unchecked_arg_slot(HANDLE); + + register_task(BATCHNORM_INIT_TASK_ID, "BatchNorm Init", init, init_task); } -bool BatchNorm::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_input, sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) 
{ - return false; - } - - int output_w = sub_output.dims[0].size; - int output_h = sub_output.dims[1].size; - int output_c = sub_output.dims[2].size; - int output_n = sub_output.dims[3].size; - BatchNormPerDeviceState *m = new BatchNormPerDeviceState( - sim->handler, this, sim->memory, output_n, output_c, output_h, output_w); - - sim->free_all(); - float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_ptr != NULL); - float *scale_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(scale_ptr != NULL); - cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset); - - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, m, input_ptr, output_ptr, scale_ptr, bias_ptr); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - float *input_grad_ptr = - (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_grad_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - float *scale_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(scale_grad_ptr != NULL); - float *bias_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_grad_ptr != NULL); - cost_metrics.weights_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - m, - input_ptr, - output_grad_ptr, 
- output_ptr, - input_grad_ptr, - scale_ptr, - scale_grad_ptr, - bias_grad_ptr, - sub_output.get_volume()); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchNorm] name(%s) size(%zu) forward_time(%.4lf) " - "backward_time(%.4lf)\n", - name, - sub_input.get_volume(), - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchNorm] name(%s) size(%zu) forward_time(%.4lf)\n", - name, - sub_input.get_volume(), - cost_metrics.forward_time); - } - // Free batchnormmeta - delete m; - return true; +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_input_slot(INPUT); + fwd.add_input_slot(SCALE); + fwd.add_input_slot(BIAS); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + + register_task(BATCHNORM_FWD_TASK_ID, "BatchNorm Fwd", fwd, forward_task); +} + +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(BATCHNORM_FWD_TASK_ID)); + + register_task(BATCHNORM_BWD_TASK_ID, "BatchNorm Bwd", bwd, backward_task); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_norm.h b/lib/runtime/src/ops/batch_norm.h index e54331665e..94bda5122b 100644 --- a/lib/runtime/src/ops/batch_norm.h +++ b/lib/runtime/src/ops/batch_norm.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_BATCH_NORM_H #include "op-attrs/ops/batch_norm.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -66,3 +66,243 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + +// void BatchNorm::init(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); +// } + +// void BatchNorm::forward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_FWD_TASK_ID, 
get_fwd_task_signature()); +// } + +// void BatchNorm::backward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); +// } + +// Tensor batch_norm(const Tensor input, bool relu, char const *name) { +// assert(input->num_dims == 4); /*NCHW*/ +// Layer *bm = new Layer(this, +// OP_BATCHNORM, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// 2 /*weights*/, +// 1 /*outputs*/, +// input); +// int numdims = 4; +// bm->outputs[0] = create_tensor_legion_ordering( +// numdims, input->dims, DT_FLOAT, bm, 0, true /*create_grad*/); +// bm->add_int_property("relu", relu); +// layers.push_back(bm); +// return bm->outputs[0]; +// } + +// BatchNorm::BatchNorm(FFModel &model, +// const ParallelTensor _input, +// const ParallelTensor _scale, +// const ParallelTensor _bias, +// bool _relu, +// char const *name) +// : Op(model, +// OP_BATCHNORM, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// 2 /*weights*/, +// 1 /*outputs*/, +// _input, +// _scale, +// _bias), +// relu(_relu) { +// assert(_input->num_dims == 4); +// numOutputs = 1; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < _input->num_dims; i++) { +// dims[i] = _input->dims[_input->num_dims - 1 - i]; +// } +// outputs[0] = +// model.create_parallel_tensor(_input->num_dims, dims, DT_FLOAT, this); +// return; +// } + +/* + locals[0] = scale + locals[1] = bias +*/ + +// void BatchNorm::init(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); +// assert(check_output_input_weight_same_parallel_is()); +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_init(ff, argmap); +// IndexLauncher launcher(BATCHNORM_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(BatchNorm)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// 
launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[1]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[1]->region)); +// launcher.add_field(3, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// set_opmeta_from_futuremap(ff, fm); +// } + +/* + regions[0]: input + regions[1]: output + regions[2](I): scale + regions[3](I): bias +*/ + +// void BatchNorm::forward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher(BATCHNORM_FWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, 0), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_DISCARD, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// 
weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[1]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[1]->region)); +// launcher.add_field(3, FID_DATA); + +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): input + regions[1](O): ouptut + regions[2](I): scale + regions[3](I): bias +*/ + +// void BatchNorm::backward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher(BATCHNORM_BWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, 0), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// // regions[0](I): input +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// // regions[1](I/O): input_grad (we only need grad tensors) +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// // regions[2](I): output +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(2, FID_DATA); +// // regions[3](I/O): output_grad +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // regions[4](I): filter +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// 
weights[0]->region)); +// launcher.add_field(4, FID_DATA); +// // regions[5](I/O): filter_grad +// launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// weights[0]->region_grad)); +// launcher.add_field(5, FID_DATA); +// // regions[6](I/O): bias_grad +// launcher.add_region_requirement(RegionRequirement(weights[1]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// weights[1]->region_grad)); +// launcher.add_field(6, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): input + regions[1](I/O): input_grad + regions[2](I): output + regions[3](I/O): output_grad + regions[4](I): scale + regions[5](I/O): scale_grad + regions[6](I/O): bias_grad +*/ From a32f4e55b34d400b4de8a7daed52f2390d116cc5 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Mon, 28 Aug 2023 10:53:23 -0700 Subject: [PATCH 08/24] cast op --- lib/kernels/include/kernels/cast_kernels.h | 17 +- lib/runtime/src/ops/cast.cc | 501 +++++---------------- lib/runtime/src/ops/cast.h | 267 ++++++++++- 3 files changed, 394 insertions(+), 391 deletions(-) diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h index d43446883c..28985f5501 100644 --- a/lib/kernels/include/kernels/cast_kernels.h +++ b/lib/kernels/include/kernels/cast_kernels.h @@ -3,19 +3,26 @@ #include "kernels/accessor.h" #include "kernels/device.h" -#include "op-attrs/ffconst.h" namespace FlexFlow { -class CastPerDeviceState : public PerDeviceOpState { -public: - CastPerDeviceState(FFHandler handle); - DataType input_data_type, output_data_type; +struct CastPerDeviceState { + PerDeviceFFHandle handle; + DataType input_data_type; + req output_data_type; }; +FF_VISITABLE_STRUCT_NO_EQ(CastPerDeviceState, + handle, + input_data_type, + output_data_type); + namespace Kernels { namespace Cast { +CastPerDeviceState + init_kernel(PerDeviceFFHandle const &, DataType 
input, DataType output); + void forward_kernel(ffStream_t stream, CastPerDeviceState const *, GenericTensorAccessorR const &input, diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index 23c1bc9940..36afdefcef 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -16,441 +16,178 @@ #include "cast.h" #include "kernels/cast_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" #include "utils/hash-utils.h" using namespace FlexFlow::Kernels::Cast; -namespace FlexFlow { - -enum Slots { - INPUT, - OUTPUT, - INPUT_GRAD, - OUTPUT_GRAD, - ATTRS, - PROFILING -} - -// declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; - -Tensor FFModel::cast(const Tensor input, DataType dtype, char const *name) { - Layer *cast = new Layer(this, - OP_CAST, - dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input); - int numdims = input->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = input->dims[i]; - } - cast->outputs[0] = create_tensor_legion_ordering( - numdims, dims, dtype, cast, 0, true /*create_grad*/); - cast->add_int_property("dtype", dtype); - layers.push_back(cast); - return cast->outputs[0]; -} - -Op *Cast::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("dtype", value); - DataType dtype = (DataType)value; - return new Cast(model, inputs[0], dtype, layer->name); -} - -CastParams Cast::get_params() const { - CastParams params; - params.dtype = this->outputs[0]->data_type; - return params; -} - -Cast::Cast(FFModel &model, - ParallelTensor 
const &input, - DataType _dtype, - char const *name) - : Op(model, - OP_CAST, - _dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input) { - numOutputs = 1; - numWeights = 0; - int numdim = input->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = input->dims[i]; - } - outputs[0] = - model.create_parallel_tensor_legion_ordering(numdim, dims, _dtype, this); -} - -Cast::Cast(FFModel &model, - CastParams const ¶ms, - ParallelTensor const &input, - char const *name) - : Cast(model, input, params.dtype, name) {} - -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - - init.add_input_slot(INPUT); - init.add_output_slot(OUTPUT); - - return init; -} - -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - - return init; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_arg_slot(ATTRS); - bwd.add_input_grad_slot(INPUT_GRAD); - bwd.add_output_grad_slot(OUTPUT_GRAD); +namespace FlexFlow { - return bwd; -} +enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, PER_DEVICE_STATE, HANDLE }; -OpTaskBinding Cast::get_init_task_binding() const { +// declare Legion names +// using Legion::ArgumentMap; +// using Legion::Context; +// using Legion::coord_t; +// using Legion::Domain; +// using Legion::FutureMap; +// using Legion::IndexLauncher; +// using Legion::PhysicalRegion; +// using Legion::Predicate; +// using Legion::Rect; +// using Legion::RegionRequirement; +// using Legion::Runtime; +// using Legion::Task; +// using Legion::TaskArgument; +// using Legion::TaskLauncher; + +OpTaskInvocation init(CastAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(PROFILING, this->profiling); - binding.bind_arg(ATTRS, this->attrs); + 
binding.bind_arg(HANDLE, ff_handle()); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {CAST_INIT_TASK_ID, binding}; } -OpTaskBinding Cast::get_fwd_task_binding() const { +OpTaskInvocation forward(CastAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind_arg(PROFILING, profiling_settings()); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {CAST_FWD_TASK_ID, binding}; } -OpTaskBinding Cast::get_bwd_task_binding() const { - OpTaskBinding binding; +OpTaskInvocation backward(CastAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - binding.bind_arg(ATTRS, this->attrs); + return {CAST_BWD_TASK_ID, binding}; +} - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { - return binding; -} + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); -void Cast::init(FFModel const &ff) { - this->execute_task(ff, CAST_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(CAST_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Cast)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // 
launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, input.data_type, output.data_type)); + return per_device_state; } -PerDeviceOpState *Cast::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - - FFHandler handler = *((FFHandler const *)task->local_args); - CastPerDeviceState *m = new CastPerDeviceState(handler); - bool profiling = acc.get_argument(PROFILING); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - m->input_data_type = input->data_type; - m->output_data_type = output->data_type; - m->profiling = profiling; - return m; + return init_task_impl(acc); } -void Cast::forward(FFModel const &ff) { - this->execute_task(ff, CAST_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(CAST_FWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, false), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // 
outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // runtime->execute_index_space(ctx, launcher); -} +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); -// template -// void Cast::forward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->output_data_type == DT_FLOAT) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_DOUBLE) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_INT32) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->output_data_type == DT_INT64) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::forward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT *input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerWO( -// regions[1], task->regions[1], FID_DATA, ctx, runtime); -// forward_kernel_wrapper( -// m, input_ptr, output_ptr, output_domain.get_volume()); -// } - -void Cast::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - CastPerDeviceState const *m = 
*((CastPerDeviceState **)task->local_args); - // if (m->input_data_type == DT_FLOAT) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_DOUBLE) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_INT32) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_INT64) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - profile(forward_kernel, - m->profiling, - "[Cast] forward_time = %.2lfms\n", - m, - input, - output) -} + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); -void Cast::backward(FFModel const &ff) { - this->execute_task(ff, CAST_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(CAST_BWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, false), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // inputs[0]->region_grad)); - // launcher.add_field(1, FID_DATA); - // runtime->execute_index_space(ctx, launcher); + return profile(forward_kernel, + profiling, + "[Cast] forward_time = %.2lfms\n", + &per_device_state, + input, + output); } -// template -// void Cast::backward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// 
CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->input_data_type == DT_FLOAT) { -// Cast::backward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->input_data_type == DT_DOUBLE) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT32) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT64) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::backward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT *input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerRW( -// regions[1], task->regions[1], FID_DATA, ctx, runtime); -// backward_kernel_wrapper( -// input_ptr, output_ptr, output_domain.get_volume()); -// } - -void Cast::backward_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); - // if (m->output_data_type == DT_FLOAT) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_DOUBLE) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_INT32) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_INT64) { - // Cast::backward_task_with_1_type(task, regions, 
ctx, runtime); - // } - auto input_grad = acc.get_tensor(INPUT); - auto output_grad = acc.get_tensor(OUTPUT); - - profile(backward_kernel, - m->profiling, - "[Cast] forward_time = %.2lfms\n", - m, - input_grad, - output_grad) + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); +} + +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + + return profile(backward_kernel, + profiling, + "[Cast] forward_time = %.2lfms\n", + &per_device_state, + input_grad, + output_grad); +} + +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -bool Cast::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + CastAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + // Assume cast has no cost - cost_metrics.forward_time = 0.0f; - cost_metrics.backward_time = 0.0f; - cost_metrics.inputs_memory = 0; - cost_metrics.outputs_memory = 0; - cost_metrics.weights_memory = 0; - return true; + float forward_time = 0.0; + float backward_time = 0.0; + float sync_time = 0.0; + return make_metrics(forward_time, backward_time, sync_time, env); } -void Cast::serialize(Legion::Serializer &sez) const { - sez.serialize(this->outputs[0]->data_type); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + + init.add_unchecked_arg_slot(HANDLE); + + init.add_input_slot(INPUT); + init.add_output_slot(OUTPUT); + + register_task(CAST_INIT_TASK_ID, "Cast Init", 
init, init_task); } -using PCG::Node; +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); -Node Cast::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 1); - DataType dtype; - dez.deserialize(dtype); - return ff.get_or_create_node(inputs[0], {dtype}); + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + + register_task(CAST_FWD_TASK_ID, "Cast Fwd", fwd, forward_task); } -Op *Cast::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - assert(num_inputs == 1); - return new Cast(ff, inputs[0], this->outputs[0]->data_type, this->name); +template <> +void register_task() { + OpTaskSignature bwd = infer_bwd_signature(get_op_signature(CAST_FWD_TASK_ID)); + + register_task(CAST_BWD_TASK_ID, "Cast Bwd", bwd, backward_task); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/cast.h b/lib/runtime/src/ops/cast.h index 7d346584d5..c3781ad783 100644 --- a/lib/runtime/src/ops/cast.h +++ b/lib/runtime/src/ops/cast.h @@ -16,8 +16,8 @@ #define _FLEXFLOW_CAST_H #include "op-attrs/ops/cast.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -33,11 +33,54 @@ OpTaskInvocation forward(CastAttrs const &); OpTaskInvocation backward(CastAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - BatchNormAttrs const &attrs, + CastAttrs const &attrs, ParallelTensorShape const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); +} // namespace FlexFlow + +#endif + +// template +// void Cast::backward_task_with_1_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// if (m->input_data_type == DT_FLOAT) { +// 
Cast::backward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->input_data_type == DT_DOUBLE) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->input_data_type == DT_INT32) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->input_data_type == DT_INT64) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } +// } + +// template +// void Cast::backward_task_with_2_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// assert(regions.size() == 2); +// assert(task->regions.size() == regions.size()); +// // Domain input_domain = runtime->get_index_space_domain( +// // ctx, task->regions[0].region.get_index_space()); +// Domain output_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// const IDT *input_ptr = helperGetTensorPointerRO( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// ODT *output_ptr = helperGetTensorPointerRW( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); +// backward_kernel_wrapper( +// input_ptr, output_ptr, output_domain.get_volume()); +// } + /* class Cast : public Op { */ /* public: */ /* Cast(FFModel &model, */ @@ -80,6 +123,222 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, /* CostMetrics &cost_metrics) const; */ /* }; */ -} // namespace FlexFlow +// void Cast::backward(FFModel const &ff) { +// this->execute_task(ff, CAST_BWD_TASK_ID, get_bwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher(CAST_BWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, false), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 
/*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } -#endif +// template +// void Cast::forward_task_with_1_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// if (m->output_data_type == DT_FLOAT) { +// Cast::forward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->output_data_type == DT_DOUBLE) { +// Cast::forward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->output_data_type == DT_INT32) { +// Cast::forward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->output_data_type == DT_INT64) { +// Cast::forward_task_with_2_type(task, regions, ctx, +// runtime); +// } +// } + +// template +// void Cast::forward_task_with_2_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// assert(regions.size() == 2); +// assert(task->regions.size() == regions.size()); +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// // Domain input_domain = runtime->get_index_space_domain( +// // ctx, task->regions[0].region.get_index_space()); +// Domain output_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// const IDT *input_ptr = helperGetTensorPointerRO( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// ODT *output_ptr = helperGetTensorPointerWO( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); +// forward_kernel_wrapper( +// m, input_ptr, output_ptr, output_domain.get_volume()); +// } + +// void Cast::forward(FFModel const &ff) { +// this->execute_task(ff, 
CAST_FWD_TASK_ID, get_fwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher(CAST_FWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, false), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +// void Cast::init(FFModel const &ff) { +// this->execute_task(ff, CAST_INIT_TASK_ID, get_init_task_signature()); +// assert(check_output_input_weight_same_parallel_is()); +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_init(ff, argmap); +// IndexLauncher launcher(CAST_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(Cast)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// set_opmeta_from_futuremap(ff, fm); +// } + +// void Cast::serialize(Legion::Serializer &sez) const { +// 
sez.serialize(this->outputs[0]->data_type); +// } + +// using PCG::Node; + +// Node Cast::deserialize(FFModel &ff, +// Legion::Deserializer &dez, +// ParallelTensor inputs[], +// int num_inputs) { +// assert(num_inputs == 1); +// DataType dtype; +// dez.deserialize(dtype); +// return ff.get_or_create_node(inputs[0], {dtype}); +// } + +// Op *Cast::materialize(FFModel &ff, +// ParallelTensor inputs[], +// int num_inputs) const { +// assert(num_inputs == 1); +// return new Cast(ff, inputs[0], this->outputs[0]->data_type, this->name); +// } + +// Cast::Cast(FFModel &model, +// ParallelTensor const &input, +// DataType _dtype, +// char const *name) +// : Op(model, +// OP_CAST, +// _dtype, +// name, +// 1 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// input) { +// numOutputs = 1; +// numWeights = 0; +// int numdim = input->num_dims; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdim; i++) { +// dims[i] = input->dims[i]; +// } +// outputs[0] = +// model.create_parallel_tensor_legion_ordering(numdim, dims, _dtype, +// this); +// } + +// Tensor FFModel::cast(const Tensor input, DataType dtype, char const *name) { +// Layer *cast = new Layer(this, +// OP_CAST, +// dtype, +// name, +// 1 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// input); +// int numdims = input->num_dims; +// int dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdims; i++) { +// dims[i] = input->dims[i]; +// } +// cast->outputs[0] = create_tensor_legion_ordering( +// numdims, dims, dtype, cast, 0, true /*create_grad*/); +// cast->add_int_property("dtype", dtype); +// layers.push_back(cast); +// return cast->outputs[0]; +// } + +// Op *Cast::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("dtype", value); +// DataType dtype = (DataType)value; +// return new Cast(model, inputs[0], dtype, layer->name); +// } + +// CastParams Cast::get_params() const { +// CastParams 
params; +// params.dtype = this->outputs[0]->data_type; +// return params; +// } + +// Cast::Cast(FFModel &model, +// CastParams const ¶ms, +// ParallelTensor const &input, +// char const *name) +// : Cast(model, input, params.dtype, name) {} From d02998d2b1aed5cd8d1c9c9ba4b687d275844914 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Mon, 28 Aug 2023 13:53:04 -0700 Subject: [PATCH 09/24] combine --- lib/kernels/include/kernels/combine_kernels.h | 11 +- lib/op-attrs/include/op-attrs/get_op_type.h | 1 + .../include/op-attrs/operator_attrs.h | 2 + lib/runtime/src/ops/combine.cc | 391 ++++++------------ lib/runtime/src/ops/combine.h | 276 ++++++++++++- 5 files changed, 400 insertions(+), 281 deletions(-) diff --git a/lib/kernels/include/kernels/combine_kernels.h b/lib/kernels/include/kernels/combine_kernels.h index 174d1eb925..44ab67d9a7 100644 --- a/lib/kernels/include/kernels/combine_kernels.h +++ b/lib/kernels/include/kernels/combine_kernels.h @@ -6,15 +6,18 @@ namespace FlexFlow { -class CombinePerDeviceState : public PerDeviceOpState { -public: - CombinePerDeviceState(FFHandler handle); - DataType data_type; +struct CombinePerDeviceState { + req data_type; }; +FF_VISITABLE_STRUCT_NO_EQ(CombinePerDeviceState, + data_type); + namespace Kernels { namespace Combine { +CombinePerDeviceState init_kernel(DataType data_type); + void forward_kernel(ffStream_t stream, CombinePerDeviceState const *m, GenericTensorAccessorR const &input, diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h index 8b451b2705..910d5dc925 100644 --- a/lib/op-attrs/include/op-attrs/get_op_type.h +++ b/lib/op-attrs/include/op-attrs/get_op_type.h @@ -12,6 +12,7 @@ OperatorType get_op_type(BatchMatmulAttrs const &); OperatorType get_op_type(BatchNormAttrs const &); OperatorType get_op_type(BroadcastAttrs const &); OperatorType get_op_type(CastAttrs const &); +OperatorType get_op_type(CombineAttrs const &); OperatorType get_op_type(ConcatAttrs const 
&); OperatorType get_op_type(Conv2DAttrs const &); OperatorType get_op_type(DropoutAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 5fd067313e..b64fe73497 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -43,6 +43,7 @@ using SharedOperatorAttrs = variant::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); +static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); diff --git a/lib/runtime/src/ops/combine.cc b/lib/runtime/src/ops/combine.cc index 2485955124..c2d8d9a017 100644 --- a/lib/runtime/src/ops/combine.cc +++ b/lib/runtime/src/ops/combine.cc @@ -13,322 +13,163 @@ * limitations under the License. */ -#include "parallel_ops/combine.h" +#include "combine.h" #include "kernels/combine_kernels.h" #include "utils/hash-utils.h" namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::LogicalPartition; -using Legion::LogicalRegion; -using Legion::Machine; -using Legion::Memory; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Combine; -CombineParams Combine::get_params() const { - CombineParams params; - params.combine_legion_dim = this->combine_dim; - params.combine_degree = this->combine_degree; - return params; -} +enum Slots { + INPUT, + OUTPUT, + PROFILING, + PER_DEVICE_STATE +}; -Combine::Combine(FFModel &model, - CombineParams const ¶ms, - ParallelTensor const input, 
- char const *name) - : Combine(model, - input, - params.combine_legion_dim, - params.combine_degree, - name) {} - -Combine::Combine(FFModel &model, - const ParallelTensor _input, - int _combine_legion_dim, - int _combine_degree, - char const *name) - : ParallelOp(model, OP_COMBINE, name, _input), - combine_dim(_combine_legion_dim), combine_degree(_combine_degree) { - int numdim = _input->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = _input->dims[i]; - } - assert(combine_degree > 0 && "Must use combine_degree > 0"); - assert(dims[combine_dim].degree % combine_degree == 0); - dims[combine_dim].degree /= combine_degree; - ParallelTensorBase::update_parallel_ids(numdim, dims); - outputs[0] = model.create_parallel_tensor_legion_ordering( - numdim, dims, DT_FLOAT, this); - // inputs[0]->print("Combine::input"); - // outputs[0]->print("Combine::output"); -} +OpTaskInvocation init(CombineAttrs const &attrs) { + OpTaskBinding binding; + + binding.bind(INPUT, input_tensor(0)); -PerDeviceOpState *Combine::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - Combine *rep = (Combine *)task->args; - // FFHandler handle = *((FFHandler *)task->local_args); - // CombineMeta* m = new CombineMeta(handle); - // m->data_type = rep->outputs[0]->data_type; - return nullptr; + return {COMBINE_INIT_TASK_ID, binding}; } -void Combine::init(FFModel const &ff) { - parallel_is = outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - assert(numOutputs == 1); - assert(numInputs == 1); - IndexLauncher launcher(COMBINE_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(Combine)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement( - input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); - launcher.add_field(0, 
FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(1, FID_DATA); - FutureMap fm = runtime->execute_index_space(ctx, launcher); - fm.wait_all_results(); +OpTaskInvocation forward(CombineAttrs const &attrs) { + OpTaskBinding binding; + + binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind_arg(PROFILING, profiling_settings()); + + binding.bind(INPUT, input_tensor(0)); + binding.bind(OUTPUT, output_tensor(0)); + + return {COMBINE_FWD_TASK_ID, binding}; } -void Combine::create_input_partition(FFModel &ff) { - assert(outputs[0]->part != LogicalPartition::NO_PART); - assert(inputs[0]->part != LogicalPartition::NO_PART); - ff.create_disjoint_partition(outputs[0]->num_dims, - outputs[0]->dims, - outputs[0]->parallel_is, - inputs[0]->region, - input_lp); - ff.create_disjoint_partition(inputs[0]->num_dims, - inputs[0]->dims, - inputs[0]->parallel_is, - outputs[0]->region_grad, - output_grad_lp); +OpTaskInvocation backward(CombineAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); + + return {COMBINE_BWD_TASK_ID, b}; } -void Combine::forward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - assert(numOutputs == 1); - assert(numInputs == 1); - assert(inputs[0]->data_type == outputs[0]->data_type); - DataType data_type = inputs[0]->data_type; - IndexLauncher launcher(COMBINE_FWD_TASK_ID, - outputs[0]->parallel_is, - TaskArgument(&data_type, sizeof(data_type)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement( - input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - 
EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(1, FID_DATA); - runtime->execute_index_space(ctx, launcher); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + + auto input = acc.get_tensor(INPUT); + + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(input.data_type)); + return per_device_state; } -void Combine::backward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - assert(numOutputs == 1); - assert(numInputs == 1); - assert(inputs[0]->data_type == outputs[0]->data_type); - DataType data_type = inputs[0]->data_type; - IndexLauncher launcher(COMBINE_BWD_TASK_ID, - inputs[0]->parallel_is, - TaskArgument(&data_type, sizeof(DataType)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - inputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement(output_grad_lp, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - inputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - runtime->execute_index_space(ctx, launcher); +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -bool Combine::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - // TODO: to be implemented - cost_metrics = CostMetrics(); - cost_metrics.forward_time = 0.05f; - cost_metrics.backward_time = 0.05f; - return true; +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = 
acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + + return profile(forward_kernel, + profiling, + "[Combine] forward_time = %.2lfms\n", + &per_device_state, + input, + output); } -bool Combine::get_int_parameter(PMParameter para, int *value) const { - switch (para) { - case PM_COMBINE_DIM: - *value = combine_dim; - return true; - case PM_COMBINE_DEGREE: - *value = combine_degree; - return true; - default: - return Op::get_int_parameter(para, value); - } +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -bool Combine::append_parallel_op_info( - std::vector ¶llel_ops) const { - ParallelOpInfo ret; - ret.op_type = op_type; - ret.parallel_dim = combine_dim; - ret.parallel_degree = combine_degree; - parallel_ops.push_back(ret); - return true; +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + + return profile(backward_kernel, + profiling, + "[Combine] forward_time = %.2lfms\n", + &per_device_state, + input_grad, + output_grad); } -tl::optional Combine::as_dot() const { - RecordFormatter rf; - { - std::ostringstream oss; - oss << "dim(" << this->combine_dim << ")"; - rf << oss.str(); - } - { - std::ostringstream oss; - oss << "deg(" << this->combine_degree << ")"; - rf << oss.str(); - } - return rf; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -/*static*/ -void Combine::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 2); - assert(task->regions.size() == 
2); - DataType data_type = *((DataType *)task->args); - if (data_type == DT_FLOAT) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_DOUBLE) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT32) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT64) { - forward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Combine forward"); - } +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + CombineAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + // TODO: to be implemented + float forward_time = 0.5; + float backward_time = 0.5; + float sync_time = 0.0; + return make_metrics(forward_time, backward_time, sync_time, env); } -template -void Combine::forward_task_with_type(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(output_domain == input_domain); - - const DT *input_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - DT *output_ptr = helperGetTensorPointerWO
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - - forward_kernel
(input_ptr, output_ptr, output_domain.get_volume()); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + + init.add_input_slot(INPUT); + + register_task(COMBINE_INIT_TASK_ID, "Combine Init", init, init_task); } -void Combine::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - DataType data_type = *((DataType *)task->args); - if (data_type == DT_FLOAT) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_DOUBLE) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT32) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT64) { - backward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Combine backward"); - } +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + register_task(COMBINE_FWD_TASK_ID, "Combine Fwd", fwd, forward_task); } -template -void Combine::backward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - Domain output_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Domain input_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(output_grad_domain == input_grad_domain); - - const DT *output_grad_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - DT *input_grad_ptr = helperGetTensorPointerRW
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - - backward_kernel
( - output_grad_ptr, input_grad_ptr, output_grad_domain.get_volume()); +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(COMBINE_FWD_TASK_ID)); + + register_task(COMBINE_BWD_TASK_ID, "Combine Bwd", bwd, backward_task); } }; // namespace FlexFlow namespace std { -size_t hash::operator()( - FlexFlow::CombineParams const ¶ms) const { - size_t key = 0; - hash_combine(key, params.combine_legion_dim); - hash_combine(key, params.combine_degree); - return key; -} }; // namespace std diff --git a/lib/runtime/src/ops/combine.h b/lib/runtime/src/ops/combine.h index 455a8ea780..4b96ce6495 100644 --- a/lib/runtime/src/ops/combine.h +++ b/lib/runtime/src/ops/combine.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_COMBINE_H #include "op-attrs/ops/combine.h" -#include "op_task_invocation.h" +#include "task_spec/op_task_invocation.h" #include "sim_environment.h" namespace FlexFlow { @@ -20,7 +20,7 @@ OpTaskInvocation backward(CombineAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, CombineAttrs const &attrs, - ParallelTensorShape const &input_shape, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); @@ -82,3 +82,275 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + +// size_t hash::operator()( +// FlexFlow::CombineParams const ¶ms) const { +// size_t key = 0; +// hash_combine(key, params.combine_legion_dim); +// hash_combine(key, params.combine_degree); +// return key; +// } + +// template +// void Combine::backward_task_with_type( +// Task const *task, +// std::vector const ®ions, +// Context ctx, +// Runtime *runtime) { +// Domain output_grad_domain = runtime->get_index_space_domain( +// ctx, task->regions[0].region.get_index_space()); +// Domain input_grad_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// assert(output_grad_domain == 
input_grad_domain); + +// const DT *output_grad_ptr = helperGetTensorPointerRO
( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// DT *input_grad_ptr = helperGetTensorPointerRW
( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); + +// backward_kernel
( +// output_grad_ptr, input_grad_ptr, output_grad_domain.get_volume()); +// } + + +// void Combine::backward_task(Task const *task, +// std::vector const ®ions, +// Context ctx, +// Runtime *runtime) { +// assert(regions.size() == 2); +// assert(task->regions.size() == 2); +// DataType data_type = *((DataType *)task->args); +// if (data_type == DT_FLOAT) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_DOUBLE) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT32) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT64) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else { +// assert(false && "Unsupported data type in Combine backward"); +// } +// } + +// bool Combine::get_int_parameter(PMParameter para, int *value) const { +// switch (para) { +// case PM_COMBINE_DIM: +// *value = combine_dim; +// return true; +// case PM_COMBINE_DEGREE: +// *value = combine_degree; +// return true; +// default: +// return Op::get_int_parameter(para, value); +// } +// } + +// bool Combine::append_parallel_op_info( +// std::vector ¶llel_ops) const { +// ParallelOpInfo ret; +// ret.op_type = op_type; +// ret.parallel_dim = combine_dim; +// ret.parallel_degree = combine_degree; +// parallel_ops.push_back(ret); +// return true; +// } + +// tl::optional Combine::as_dot() const { +// RecordFormatter rf; +// { +// std::ostringstream oss; +// oss << "dim(" << this->combine_dim << ")"; +// rf << oss.str(); +// } +// { +// std::ostringstream oss; +// oss << "deg(" << this->combine_degree << ")"; +// rf << oss.str(); +// } +// return rf; +// } + + +// void Combine::init(FFModel const &ff) { +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// assert(numOutputs == 1); +// assert(numInputs == 1); +// IndexLauncher launcher(COMBINE_INIT_TASK_ID, +// 
parallel_is, +// TaskArgument(this, sizeof(Combine)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement( +// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// } + +// void Combine::create_input_partition(FFModel &ff) { +// assert(outputs[0]->part != LogicalPartition::NO_PART); +// assert(inputs[0]->part != LogicalPartition::NO_PART); +// ff.create_disjoint_partition(outputs[0]->num_dims, +// outputs[0]->dims, +// outputs[0]->parallel_is, +// inputs[0]->region, +// input_lp); +// ff.create_disjoint_partition(inputs[0]->num_dims, +// inputs[0]->dims, +// inputs[0]->parallel_is, +// outputs[0]->region_grad, +// output_grad_lp); +// } + +// void Combine::forward(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// assert(numOutputs == 1); +// assert(numInputs == 1); +// assert(inputs[0]->data_type == outputs[0]->data_type); +// DataType data_type = inputs[0]->data_type; +// IndexLauncher launcher(COMBINE_FWD_TASK_ID, +// outputs[0]->parallel_is, +// TaskArgument(&data_type, sizeof(data_type)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement( +// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// 
launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +// void Combine::backward(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// assert(numOutputs == 1); +// assert(numInputs == 1); +// assert(inputs[0]->data_type == outputs[0]->data_type); +// DataType data_type = inputs[0]->data_type; +// IndexLauncher launcher(COMBINE_BWD_TASK_ID, +// inputs[0]->parallel_is, +// TaskArgument(&data_type, sizeof(DataType)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// inputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(output_grad_lp, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + + + +// CombineParams Combine::get_params() const { +// CombineParams params; +// params.combine_legion_dim = this->combine_dim; +// params.combine_degree = this->combine_degree; +// return params; +// } + +// Combine::Combine(FFModel &model, +// CombineParams const ¶ms, +// ParallelTensor const input, +// char const *name) +// : Combine(model, +// input, +// params.combine_legion_dim, +// params.combine_degree, +// name) {} + +// Combine::Combine(FFModel &model, +// const ParallelTensor _input, +// int _combine_legion_dim, +// int _combine_degree, +// char const *name) +// : ParallelOp(model, OP_COMBINE, name, _input), +// combine_dim(_combine_legion_dim), combine_degree(_combine_degree) { +// int numdim = _input->num_dims; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdim; i++) { +// dims[i] = _input->dims[i]; +// } +// assert(combine_degree > 0 && "Must use combine_degree > 
0"); +// assert(dims[combine_dim].degree % combine_degree == 0); +// dims[combine_dim].degree /= combine_degree; +// ParallelTensorBase::update_parallel_ids(numdim, dims); +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// numdim, dims, DT_FLOAT, this); +// // inputs[0]->print("Combine::input"); +// // outputs[0]->print("Combine::output"); +// } + +// /*static*/ +// void Combine::forward_task(Task const *task, +// std::vector const ®ions, +// Context ctx, +// Runtime *runtime) { +// assert(regions.size() == 2); +// assert(task->regions.size() == 2); +// DataType data_type = *((DataType *)task->args); +// if (data_type == DT_FLOAT) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_DOUBLE) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT32) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT64) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else { +// assert(false && "Unsupported data type in Combine forward"); +// } +// } + +// template +// void Combine::forward_task_with_type(Task const *task, +// std::vector const ®ions, +// Context ctx, +// Runtime *runtime) { +// Domain input_domain = runtime->get_index_space_domain( +// ctx, task->regions[0].region.get_index_space()); +// Domain output_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// assert(output_domain == input_domain); + +// const DT *input_ptr = helperGetTensorPointerRO
( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// DT *output_ptr = helperGetTensorPointerWO
( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); + +// forward_kernel
(input_ptr, output_ptr, output_domain.get_volume()); +// } From f2205f43b18ca97e332c11d56668e1c0516c6826 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:26:59 -0700 Subject: [PATCH 10/24] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 23522a89f7..6850dec178 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -19,8 +19,8 @@ FF_VISITABLE_STRUCT_NO_EQ( namespace Kernels { namespace BatchMatmul { -BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, - Allocator allocator, +BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator const &allocator, int a_seq_length_dim, int b_seq_length_dim); From 2f4662da4fcb64dad9de7dd6294e8e6baa3e532e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:28:45 -0700 Subject: [PATCH 11/24] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 6850dec178..e5062b2c61 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -26,7 +26,7 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, - float *o_ptr, + float *output_ptr, float const *a_ptr, float const *b_ptr, float const *c_ptr, From 642eb90d99bdb83b3a212be50232520b497816ba Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:34:04 -0700 Subject: [PATCH 12/24] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 5 ++--- lib/runtime/src/ops/batch_matmul.cc | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff 
--git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index e5062b2c61..ec32648d0f 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -27,9 +27,8 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, float *output_ptr, - float const *a_ptr, - float const *b_ptr, - float const *c_ptr, + float const *lhs_input_ptr, + float const *rhs_input_ptr, int m, int n, int k, diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 59570f4417..09eb1f23c7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -134,7 +134,6 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { output.get_float_ptr(), a_input.get_float_ptr(), b_input.get_float_ptr(), - nullptr, // c_ptr m, n, k, From e79406a842117e7823d8d63ed21ad5613c846078 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:46:00 -0700 Subject: [PATCH 13/24] change --- lib/kernels/src/cuda/batch_matmul_kernels.cu | 5 +---- lib/kernels/src/hip/batch_matmul_kernels.cpp | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index 3593ac4ab2..cde0df93c0 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -18,9 +18,6 @@ namespace FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -124,7 +121,7 @@ O = A * B */ void forward_kernel(cudaStream_t stream, - BatchMatmulPerDeviceState const *meta, + BatchMatmulPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp 
b/lib/kernels/src/hip/batch_matmul_kernels.cpp index e5334b1841..a06442d3d6 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -29,7 +29,7 @@ O: (batch, n, m) O = A * B */ void forward_kernel(hipStream_t stream, - BatchMatmulPerDeviceState const *meta, + BatchMatmulPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, From bb3c10f8a9f46c9a09d514a9c70f4e72e37e3e94 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:51:31 -0700 Subject: [PATCH 14/24] change --- lib/op-attrs/include/op-attrs/ops/batch_matmul.h | 3 --- lib/runtime/src/ops/batch_matmul.cc | 6 +----- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 6ac2a2cf69..00c700ba20 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -14,9 +14,6 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); -ParallelTensorShape get_weights_shape(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 09eb1f23c7..669cad215b 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -80,9 +80,7 @@ static DeviceSpecific DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); - - // assert(weight.shape.get_volume() * sizeof(float) == - // acc.unwrap(per_device_state)->weightSize); + return per_device_state; } @@ -230,8 +228,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, ParallelTensorShape output_shape 
= get_output_shape(attrs, a_input.shape, b_input.shape); - ParallelTensorShape weight_shape = - get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); From 848850900f9d44fe29df90c69a2099cb71be299e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:00:34 -0700 Subject: [PATCH 15/24] change --- lib/runtime/include/runtime/config.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index 21cfca20cf..ef7e779469 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -107,7 +107,7 @@ struct FFConfig : public use_visitable_cmp { struct FFIterationConfig { FFIterationConfig(); void reset(); - req seq_length; + int seq_length; }; FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); From 36eba295424c5c58ea0e3aedea76e3f3adfd3790 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:03:04 -0700 Subject: [PATCH 16/24] change --- lib/runtime/src/ops/batch_matmul.h | 55 ------------------------------ 1 file changed, 55 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index 42c9bc23de..018fe1d582 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -30,61 +30,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, ProfilingSettings const &settings, MachineView const &pc); -/* class BatchMatmul : public Op { */ -/* public: */ -/* BatchMatmul(FFModel &model, */ -/* const ParallelTensor A, */ -/* const ParallelTensor B, */ -/* int a_seq_length_dim, */ -/* int b_seq_length_dim, */ -/* char const *name = nullptr); */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* void init(FFModel const &) override; */ -/* void 
forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* /1* static PCG::Node deserialize(FFModel &ff, *1/ */ -/* /1* Legion::Deserializer &d, *1/ */ -/* /1* ParallelTensor inputs[], *1/ */ -/* /1* int num_inputs); *1/ */ -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* OpTaskBinding get_init_task_binding() const override; */ -/* OpTaskBinding get_fwd_task_binding() const override; */ -/* OpTaskBinding get_bwd_task_binding() const override; */ -/* private: */ -/* template */ -/* void init_with_dim(FFModel const &ff); */ -/* template */ -/* void forward_with_dim(FFModel const &ff); */ -/* template */ -/* void backward_with_dim(FFModel const &ff); */ - -/* public: */ -/* int a_seq_length_dim, b_seq_length_dim; */ -/* }; */ - } // namespace FlexFlow #endif From 98773b4237de4a4585711bca14d26b9d3c6e75ff Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:15:53 -0700 Subject: [PATCH 17/24] change --- lib/runtime/src/ops/batch_matmul.cc | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 669cad215b..a8c2ec7bd7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -72,8 +72,8 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { - auto 
const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); - auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); + int const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); + int const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = acc.get_allocator(); @@ -94,9 +94,6 @@ static DeviceSpecific } static optional forward_task_impl(TaskArgumentAccessor const &acc) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - auto a_input = acc.get_tensor(A_INPUT); auto b_input = acc.get_tensor(B_INPUT); auto output = acc.get_tensor(OUTPUT); @@ -148,10 +145,6 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - // Currently assume C is NULL - assert(regions.size() == 6); - assert(task->regions.size() == 6); - // BatchMatmul* bmm = (BatchMatmul*) task->args; FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); @@ -161,15 +154,15 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); // is this equivalent to checking `Domain` equality? 
- assert(output == output_grad); + assert(output.shape == output_grad.shape); auto a_input = acc.get_tensor(A_INPUT); auto a_input_grad = acc.get_tensor_grad(A_INPUT); - assert(a_input == a_input_grad); + assert(a_input.shape == a_input_grad.shape); auto b_input = acc.get_tensor(B_INPUT); auto b_input_grad = acc.get_tensor_grad(B_INPUT); - assert(b_input == b_input_grad); + assert(b_input.shape == b_input_grad.shape); // check dins int m = b_input.shape[legion_dim_t(0)]; @@ -181,8 +174,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); - i++) { //@colin get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.dims.num_dims(); + i++) { int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); From ae592613cf3fa6c763ca0f14b30a1899831ee4a9 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:28:18 -0700 Subject: [PATCH 18/24] change --- lib/runtime/src/task_spec/op_arg_ref.h | 10 ++-------- lib/runtime/src/task_spec/runtime_arg_ref.h | 2 ++ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index f2361af05d..20ce6892a2 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -4,15 +4,13 @@ #include "arg_ref.h" #include "device_specific.h" #include "op-attrs/parallel_tensor_shape.h" -#include "runtime/config.h" namespace FlexFlow { enum class OpArgRefType { PER_DEVICE_OP_STATE, - PARALLEL_TENSOR_SHAPE, - ITERATION_CONFIG -}; + PARALLEL_TENSOR_SHAPE + }; template using OpArgRef = ArgRef; @@ -28,10 +26,6 @@ OpArgRef input_parallel_tensor_shape(int idx) { return {OpArgRefType::PARALLEL_TENSOR_SHAPE}; } -OpArgRef iteration_config() { - 
return {OpArgRefType::ITERATION_CONFIG}; -} - } // namespace FlexFlow #endif diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.h b/lib/runtime/src/task_spec/runtime_arg_ref.h index 6b4345091a..033c2bcfbc 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.h +++ b/lib/runtime/src/task_spec/runtime_arg_ref.h @@ -3,6 +3,7 @@ #include "arg_ref.h" #include "device_specific.h" +#include "runtime/config.h" namespace FlexFlow { @@ -15,6 +16,7 @@ using RuntimeArgRefSpec = ArgRefSpec; RuntimeArgRef profiling_settings(); RuntimeArgRef> ff_handle(); +RuntimeArgRef iteration_config(); } // namespace FlexFlow From 8db251252670a8107430318da166e2ba059879ca Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 17:50:28 -0700 Subject: [PATCH 19/24] fix asserts --- lib/runtime/src/ops/attention.cc | 15 --------------- lib/runtime/src/ops/batch_matmul.cc | 13 ++++--------- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/lib/runtime/src/ops/attention.cc b/lib/runtime/src/ops/attention.cc index bca87bdb53..94e2b03731 100644 --- a/lib/runtime/src/ops/attention.cc +++ b/lib/runtime/src/ops/attention.cc @@ -121,18 +121,6 @@ static DeviceSpecific int num_samples = get_piece_shape(query_parallel_tensor_shape)[ff_dim_t(2)]; int num_heads = get_piece_shape(weight_parallel_tensor_shape)[ff_dim_t(1)]; - assert(qoSeqLength == query.shape[legion_dim_t(1)]); - assert(qSize == query.shape[legion_dim_t(0)]); - assert(num_samples == key.shape[legion_dim_t(2)]); - assert(kvSeqLength == key.shape[legion_dim_t(1)]); - assert(kSize == key.shape[legion_dim_t(0)]); - assert(num_samples == value.shape[legion_dim_t(2)]); - assert(kvSeqLength == value.shape[legion_dim_t(1)]); - assert(vSize == value.shape[legion_dim_t(0)]); - assert(num_samples == output.shape[legion_dim_t(2)]); - assert(qoSeqLength == output.shape[legion_dim_t(1)]); - assert(oProjSize == output.shape[legion_dim_t(0)]); - DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, @@ 
-149,9 +137,6 @@ static DeviceSpecific qoSeqLength, kvSeqLength, attrs.add_bias_kv)); - - assert(weight.shape.get_volume() * sizeof(float) == - acc.unwrap(per_device_state)->weightSize); return per_device_state; } diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index a8c2ec7bd7..00699652e7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -110,8 +110,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { int k = a_input.shape[legion_dim_t(0)]; assert(k == b_input.shape[legion_dim_t(1)]); - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; for (int i = 2; i < a_input.shape.get_dim(); @@ -171,8 +171,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(n == output.shape[legion_dim_t(1)]); int k = a_input.shape[legion_dim_t(0)]; assert(k == b_input.shape[legion_dim_t(1)]); - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; for (int i = 2; i < a_input.shape.dims.num_dims(); i++) { @@ -182,11 +182,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { batch *= dim_size; } - // TODO: add support for meta->a_seq_length_dim >= 0 - // or meta->b_seq_length_dim >= 0 - assert((meta->a_seq_length_dim >= a_len) || (iter_config.seq_length == 0)); - assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); - return profile(backward_kernel, profiling, "[BatchMatmul] backward_time = %.2lfms\n", From e8a6c30a439fcb1168986939de909d6c717c66f5 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 
17:57:44 -0700 Subject: [PATCH 20/24] remove asserts --- lib/runtime/src/ops/batch_norm.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lib/runtime/src/ops/batch_norm.cc b/lib/runtime/src/ops/batch_norm.cc index ffd52c96fb..6ebf359051 100644 --- a/lib/runtime/src/ops/batch_norm.cc +++ b/lib/runtime/src/ops/batch_norm.cc @@ -130,9 +130,6 @@ static DeviceSpecific } static optional forward_task_impl(TaskArgumentAccessor const &acc) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -161,9 +158,6 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - assert(regions.size() == 7); - assert(task->regions.size() == 7); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); From 09acbe545ad4d6d1cab1e14e55bf1d7e9a638954 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 17:59:10 -0700 Subject: [PATCH 21/24] format --- lib/runtime/src/ops/batch_matmul.cc | 5 ++--- lib/runtime/src/task_spec/op_arg_ref.h | 5 +---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 00699652e7..45c5e11b9c 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -80,7 +80,7 @@ static DeviceSpecific DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); - + return per_device_state; } @@ -174,8 +174,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.get_volume() == b_input.shape.get_volume()); assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; - for (int i = 2; i < a_input.shape.dims.num_dims(); - i++) { + for (int i = 2; i < 
a_input.shape.dims.num_dims(); i++) { int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index 20ce6892a2..3e931d79a4 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -7,10 +7,7 @@ namespace FlexFlow { -enum class OpArgRefType { - PER_DEVICE_OP_STATE, - PARALLEL_TENSOR_SHAPE - }; +enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE }; template using OpArgRef = ArgRef; From 67d2d1ed9620748ee2a415f7ce64929e613b3b47 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Thu, 31 Aug 2023 12:39:08 -0700 Subject: [PATCH 22/24] concat --- lib/kernels/include/kernels/concat_kernels.h | 15 +- lib/kernels/src/hip/concat_kernels.cpp | 4 - lib/op-attrs/include/op-attrs/ops/concat.h | 4 +- lib/runtime/src/ops/concat.cc | 610 ++++-------------- lib/runtime/src/ops/concat.h | 335 +++++++++- lib/runtime/src/serialization.h | 9 +- .../src/task_spec/op_task_invocation.h | 17 +- lib/runtime/src/task_spec/op_tensor_spec.h | 21 + .../src/task_spec/variadic_tensor_ref.h | 26 + 9 files changed, 540 insertions(+), 501 deletions(-) create mode 100644 lib/runtime/src/task_spec/op_tensor_spec.h create mode 100644 lib/runtime/src/task_spec/variadic_tensor_ref.h diff --git a/lib/kernels/include/kernels/concat_kernels.h b/lib/kernels/include/kernels/concat_kernels.h index 741bbbe9f0..1bbefdf843 100644 --- a/lib/kernels/include/kernels/concat_kernels.h +++ b/lib/kernels/include/kernels/concat_kernels.h @@ -6,28 +6,27 @@ namespace FlexFlow { -class ConcatPerDeviceState : public PerDeviceOpState { -public: - ConcatPerDeviceState(FFHandler handle) : PerDeviceOpState(handle){}; - int legion_axis; - char op_name[MAX_OPNAME]; +struct ConcatPerDeviceState { + req legion_axis; }; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ConcatPerDeviceState, 
legion_axis); + namespace Kernels { namespace Concat { -void init_meta(ConcatPerDeviceState *meta, int legion_axis); +ConcatPerDeviceState init_kernel(ff_dim_t legion_axis); void forward_kernel(ffStream_t stream, ConcatPerDeviceState const *m, GenericTensorAccessorW const &output, - GenericTensorAccessorR const *inputs, + std::vector const &inputs, int num_inputs); void backward_kernel(ffStream_t stream, ConcatPerDeviceState const *m, GenericTensorAccessorR const &output_grad, - GenericTensorAccessorW const *input_grads, + std::vector const &input_grads, int num_inputs); } // namespace Concat diff --git a/lib/kernels/src/hip/concat_kernels.cpp b/lib/kernels/src/hip/concat_kernels.cpp index e818f8b568..f943bc9156 100644 --- a/lib/kernels/src/hip/concat_kernels.cpp +++ b/lib/kernels/src/hip/concat_kernels.cpp @@ -26,10 +26,6 @@ using Legion::Rect; namespace Kernels { namespace Concat { -void init_meta(ConcatPerDeviceState *m, int legion_axis) { - m->legion_axis = legion_axis; -} - template void calc_blk_size(coord_t &num_blocks, coord_t &blk_size, diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index b9bd14a231..cbc864be44 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -9,9 +9,9 @@ namespace FlexFlow { struct ConcatAttrs { - ff_dim_t axis; + req axis; }; -FF_VISITABLE_STRUCT(ConcatAttrs, axis); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ConcatAttrs, axis); CHECK_VALID_OP_ATTR(ConcatAttrs); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/concat.cc b/lib/runtime/src/ops/concat.cc index f17a33b956..20c54c993e 100644 --- a/lib/runtime/src/ops/concat.cc +++ b/lib/runtime/src/ops/concat.cc @@ -16,537 +16,197 @@ #include "concat.h" #include "kernels/concat_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" +#include "task_spec/variadic_tensor_ref.h" #include "utils/hash-utils.h" +#include "op-attrs/get_output_shapes.h" namespace 
FlexFlow { -enum Slots { - INPUTS, - OUTPUT, - INPUT_GRADS, - OUTPUT_GRAD, - ATTRS, - PROFILING -} +using namespace FlexFlow::Kernels::Concat; -// declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; -using PCG::Node; -using namespace FlexFlow::Kernels::Concat; +enum Slots { + INPUTS, + OUTPUT, + ATTRS, + PROFILING, + HANDLE, + PER_DEVICE_STATE, + NUM_INPUTS +}; -bool operator==(ConcatParams const &lhs, ConcatParams const &rhs) { - return lhs.axis == rhs.axis; -} +OpTaskInvocation init(ConcatAttrs const &attrs) { + OpTaskBinding binding; -ConcatParams Concat::get_params() const { - ConcatParams params; - params.axis = legion_axis; - return params; -} + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(ATTRS, attrs); -Tensor - FFModel::concat(int n, Tensor const *tensors, int axis, char const *name) { - Layer *concat = new Layer(this, - OP_CONCAT, - DT_FLOAT, - name, - n /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - tensors); - int numdim = tensors[0]->num_dims; - // Making sure axis is between [0, numdim) - axis = (axis % numdim + numdim) % numdim; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = tensors[0]->dims[i]; - } - for (int i = 1; i < n; i++) { - assert(tensors[i]->data_type == tensors[0]->data_type); - assert(tensors[i]->num_dims == tensors[0]->num_dims); - for (int j = 0; j < numdim; j++) { - if (j != numdim - axis - 1) { - assert(tensors[i]->dims[j] == tensors[0]->dims[j]); - } else { - dims[j] += tensors[i]->dims[j]; - } - } - } - concat->outputs[0] = create_tensor_legion_ordering( - numdim, dims, tensors[0]->data_type, concat, 0, true /*create_grad*/); - 
concat->add_int_property("legion_axis", numdim - axis - 1); - layers.push_back(concat); - return concat->outputs[0]; + return {CONCAT_INIT_TASK_ID, binding}; } -Op *Concat::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("legion_axis", value); - int legion_axis = value; - return new Concat( - model, inputs.size(), inputs.data(), legion_axis, layer->name); -} +OpTaskInvocation forward(ConcatAttrs const &attrs) { + OpTaskBinding binding; + binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind(INPUTS, get_input_tensors()); + binding.bind(OUTPUT, output_tensor(0)); + binding.bind(NUM_INPUTS, get_number_inputs()); + binding.bind_arg(PROFILING, profiling_settings()); -Concat::Concat(FFModel &model, - int _n, - ParallelTensor const *_tensors, - int _legion_axis, - char const *name) - : Op(model, - OP_CONCAT, - DT_FLOAT, - name, - _n /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - _tensors), - legion_axis(_legion_axis) { - int num_dim = inputs[0]->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < num_dim; i++) { - dims[i] = inputs[0]->dims[i]; - } - for (int i = 1; i < numInputs; i++) { - assert(inputs[i]->data_type == inputs[0]->data_type); - assert(inputs[i]->num_dims == inputs[0]->num_dims); - for (int j = 0; j < num_dim; j++) { - if (j != legion_axis) { - assert(inputs[i]->dims[j] == inputs[0]->dims[j]); - } else { - // Assert that the concat dim cannot be parallelized - assert(inputs[i]->dims[j].parallel_idx == -1); - assert(inputs[i]->dims[j].degree == 1); - dims[j].size += inputs[i]->dims[j].size; - } - } - } - numOutputs = 1; - outputs[0] = model.create_parallel_tensor_legion_ordering( - num_dim, dims, inputs[0]->data_type, this); + return {CONCAT_FWD_TASK_ID, binding}; } -Concat::Concat(FFModel &model, - ConcatParams const ¶ms, - std::vector const &inputs, - char const *name) - : Concat(model, inputs.size(), inputs.data(), 
params.axis, name) {} +OpTaskInvocation backward(ConcatAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); + return {CONCAT_BWD_TASK_ID, b}; +} - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + auto const &attrs = acc.get_argument(ATTRS); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); - init.add_input_slot(INPUTS, SlotType::VARIADIC); - init.add_output_slot(OUTPUT); + DeviceSpecific per_device_state = + acc.create_device_specific(init_kernel(attrs.axis)); + return per_device_state; +} - return init; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + int number_of_inputs = acc.get_argument(NUM_INPUTS); + ProfilingSettings profiling = acc.get_argument(PROFILING); - fwd.add_arg_slot(ATTRS); + auto output = acc.get_tensor(OUTPUT); + auto inputs = acc.get_variadic_tensor(INPUTS); - fwd.add_input_slot(INPUTS, SlotType::VARIADIC); - fwd.add_output_slot(OUTPUT); + return profile(forward_kernel, + profiling, + "[Concat] forward_time = %.2lfms\n", + &per_device_state, + output, + inputs, + number_of_inputs); +} - return init; +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto 
per_device_state = acc.get_argument(PER_DEVICE_STATE); + int number_of_inputs = acc.get_argument(NUM_INPUTS); + ProfilingSettings profiling = acc.get_argument(PROFILING); - bwd.add_arg_slot(ATTRS); + auto input_grads = acc.get_variadic_tensor_grad(INPUTS); + auto output_grad = acc.get_tensor_grad(OUTPUT); - bwd.add_input_grad_slot(INPUT_GRADS, SlotType::VARIADIC); - bwd.add_output_grad_slot(OUTPUT_GRAD); + assert(number_of_inputs <= MAX_NUM_INPUTS); - return bwd; + return profile(backward_kernel, + profiling, + "[Concat] backward_time = %.2lfms\n", + &per_device_state, + output_grad, + input_grads, + number_of_inputs); } -OpTaskBinding Concat::get_init_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(PROFILING, this->profiling); - binding.bind_arg(ATTRS, this->attrs); - - return binding; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -OpTaskBinding Concat::get_fwd_task_binding() const { - OpTaskBinding binding; +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + ConcatAttrs const &attrs, + InputVariadicParallelTensorDesc const &inputs_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + int numInputs = (inputs_shape.shapes).size(); + assert(numInputs <= MAX_NUM_INPUTS); - binding.bind_arg(ATTRS, this->attrs); + auto env = sim.new_environment(); - for (int i = 0; i < this->attrs.n; i++) { - binding.bind(INPUTS, input_tensor(i)); - } + ParallelTensorShape output_shape = get_output_shape(attrs, inputs_shape.shapes); - binding.bind(OUTPUT, output_tensor(0)); + SimTaskBinding init_binding; + init_binding.bind_arg(PROFILING, settings); + init_binding.bind_arg(ATTRS, attrs); - return binding; -} + auto init_accessor = + env.get_init_accessor(CONCAT_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); -OpTaskBinding 
Concat::get_bwd_task_binding() const { - OpTaskBinding binding; + SimTaskBinding fwd_binding; + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + fwd_binding.bind(INPUTS, inputs_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(NUM_INPUTS, numInputs); + fwd_binding.bind_arg(PROFILING, settings); - binding.bind_arg(ATTRS, this->attrs); + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - for (int i = 0; i < this->attrs.n; i++) { - binding.bind(INPUT_GRADS, input_tensor(i).grad()); - } + auto fwd_accessor = env.get_fwd_accessor(CONCAT_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(CONCAT_BWD_TASK_ID, bwd_binding); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); - return binding; -} + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); -void Concat::init(FFModel const &ff) { - this->execute_task(ff, CONCAT_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(CONCAT_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Concat)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // 
inputs[i]->region)); - // launcher.add_field(i + 1, FID_DATA); - // } - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // inputs[i]->region_grad)); - // launcher.add_field(i + numInputs + 1, FID_DATA); - // } - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); } -PerDeviceOpState *Concat::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - - FFHandler handler = *((FFHandler const *)task->local_args); - ConcatPerDeviceState *m = new ConcatPerDeviceState(handler); - // Note that our internal axis index ordering is opposite to other frameworks - init_meta(m, attrs.legion_axis); - m->profiling = profiling; - std::strcpy(m->op_name, attrs.name); - return m; -} +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); -void Concat::forward(FFModel const &ff) { - this->execute_task(ff, CONCAT_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(CONCAT_FWD_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Concat)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, - // 0 /*projection id*/, - // READ_ONLY, - 
// EXCLUSIVE, - // inputs[i]->region)); - // launcher.add_field(i + 1, FID_DATA); - // } - // runtime->execute_index_space(ctx, launcher); + init.add_arg_slot(ATTRS); + init.add_arg_slot(PROFILING); + + register_task(CONCAT_INIT_TASK_ID, "Concat Init", init, init_task); } -/* - regions[0](O): output - regions[1..numInputs](I): inputs -*/ -void Concat::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - // Concat const *cc = (Concat *)task->args; - ConcatPerDeviceState const *m = *((ConcatPerDeviceState **)task->local_args); - // Note that our internal axis index ordering is opposite to other frameworks - assert(regions.size() == attrs.n + 1); - assert(task->regions.size() == attrs.n + 1); - // Domain out_domain = runtime->get_index_space_domain( - // ctx, task->regions[0].region.get_index_space()); - // GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - // DT_FLOAT, regions[0], task->regions[0], FID_DATA, ctx, runtime); - // assert(out_domain.get_dim() == cc->outputs[0].num_dims); - // Domain in_domain[MAX_NUM_INPUTS]; - // for (int i = 0; i < cc->numInputs; i++) - // in_domain[i] = runtime->get_index_space_domain( - // ctx, task->regions[i + 1].region.get_index_space()); - // float *output = helperGetTensorPointerWO( - // regions[0], task->regions[0], FID_DATA, ctx, runtime); - - auto output = acc.get_tensor(OUTPUT); - auto inputs = acc.get_variadic_tensor(INPUTS); - - // GenericTensorAccessorR inputs[MAX_NUM_INPUTS]; - // for (int i = 0; i < attrs.n; i++) { - // // inputs[i] = helperGetTensorPointerRO( - // // regions[i + 1], task->regions[i + 1], FID_DATA, ctx, runtime); - // inputs[i] = helperGetGenericTensorAccessorRO( - // DT_FLOAT, regions[i + 1], task->regions[i + 1], FID_DATA, ctx, - // runtime); - // } - profile(forward_kernel, - m->profiling, - "[Concat] forward_time = %.2lfms\n", - m, - output, - inputs, - attrs.n) -} +template <> +void 
register_task() { + OpTaskSignature fwd(OpTaskType::FWD); -void Concat::backward(FFModel const &ff) { - this->execute_task(ff, CONCAT_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(CONCAT_BWD_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Concat)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(0, FID_DATA); - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // inputs[i]->region_grad)); - // // LogicalRegion lr = inputs[i]->region_grad; - // // printf("concat[%d]: region(%d,%d,%d)\n", i+1, - // // lr.get_index_space().get_id(), lr.get_field_space().get_id(), - // // lr.get_tree_id()); - // launcher.add_field(i + 1, FID_DATA); - // } - // runtime->execute_index_space(ctx, launcher); -} + fwd.add_arg_slot(NUM_INPUTS); + fwd.add_arg_slot(PROFILING); + fwd.add_input_slot(INPUTS, SlotType::VARIADIC); + fwd.add_output_slot(OUTPUT); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); -/* - regions[0](I): output_grad - regions[1..numInputs](I/O): input_grad -*/ -void Concat::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - // Concat const *cc = (Concat *)task->args; - ConcatPerDeviceState const *m = *((ConcatPerDeviceState **)task->local_args); - // Note that our internal axis index ordering is opposite to other frameworks - assert(regions.size() == attrs.n + 1); - assert(task->regions.size() == attrs.n 
+ 1); - assert(attrs.n <= MAX_NUM_INPUTS); - // Domain out_grad_domain = runtime->get_index_space_domain( - // ctx, task->regions[0].region.get_index_space()); - // assert(out_grad_domain.get_dim() == cc->outputs[0].num_dims); - // Domain in_grad_domains[MAX_NUM_INPUTS]; - // for (int i = 0; i < cc->numInputs; i++) - // in_grad_domains[i] = runtime->get_index_space_domain( - // ctx, task->regions[i + 1].region.get_index_space()); - // float const *output_grad = helperGetTensorPointerRO( - // regions[0], task->regions[0], FID_DATA, ctx, runtime); - - auto input_grads = acc.get_variadic_tensor(INPUT_GRADS); - auto output_grad = acc.get_tensor(OUTPUT_GRAD); - - // GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( - // DT_FLOAT, regions[0], task->regions[0], FID_DATA, ctx, runtime); - // GenericTensorAccessorW input_grads[MAX_NUM_INPUTS]; - // for (int i = 0; i < attrs.n; i++) { - // // input_grads[i] = helperGetTensorPointerRW( - // // regions[i + 1], task->regions[i + 1], FID_DATA, ctx, runtime); - // input_grads[i] = helperGetGenericTensorAccessorRW( - // DT_FLOAT, regions[i + 1], task->regions[i + 1], FID_DATA, ctx, - // runtime); - // } - - profile(backward_kernel, - m->profiling, - "[Concat] backward_time = %.2lfms\n", - m, - output_grad, - input_grads, - attrs.n) + register_task(CONCAT_FWD_TASK_ID, "Concat Fwd", fwd, forward_task); } -bool Concat::get_int_parameter(PMParameter para, int *value) const { - switch (para) { - case PM_AXIS: - *value = legion_axis; - return true; - default: - return Op::get_int_parameter(para, value); - } -} +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(CONCAT_FWD_TASK_ID)); -bool Concat::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - assert(numInputs <= MAX_NUM_INPUTS); - ParallelTensorBase sub_inputs[MAX_NUM_INPUTS], sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - 
for (int i = 0; i < numInputs; i++) { - if (!inputs[i]->get_sub_tensor(mv, sub_inputs[i])) { - return false; - } - } - - ConcatPerDeviceState *m = sim->concat_meta; - init_meta(m, this->legion_axis); - - sim->free_all(); - float *input_ptrs[MAX_NUM_INPUTS]; - float *input_grad_ptrs[MAX_NUM_INPUTS]; - bool out_of_memory = false; - for (int i = 0; i < numInputs; i++) { - input_ptrs[i] = - (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); - out_of_memory = out_of_memory || (input_ptrs[i] == NULL); - } - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - Domain out_domain = sub_output.get_domain(); - float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - GenericTensorAccessorW output_acc(DT_FLOAT, out_domain, output_ptr); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - out_of_memory = out_of_memory || (output_ptr == NULL); - if (out_of_memory) { - cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - return true; - } - - Domain in_domains[MAX_NUM_INPUTS]; - GenericTensorAccessorR input_acc[MAX_NUM_INPUTS]; - for (int i = 0; i < numInputs; i++) { - in_domains[i] = sub_inputs[i].get_domain(); - input_acc[i] = - GenericTensorAccessorR(DT_FLOAT, in_domains[i], input_ptrs[i]); - } - - assert(m->profiling == false); - - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, m, output_acc, input_acc, numInputs); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - GenericTensorAccessorW input_grad_accs[MAX_NUM_INPUTS]; - for (int i = 0; i < numInputs; i++) { - input_grad_ptrs[i] = - (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); - out_of_memory = out_of_memory || (input_grad_ptrs[i] == NULL); - input_grad_accs[i] = - GenericTensorAccessorW(DT_FLOAT, in_domains[i], input_grad_ptrs[i]); - } - cost_metrics.inputs_memory += 
cost_metrics.total_mem_diff_from(sim->offset); - float *output_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - GenericTensorAccessorR output_grad_acc( - DT_FLOAT, out_domain, output_grad_ptr); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - out_of_memory = out_of_memory || (output_grad_ptr == NULL); - if (out_of_memory) { - cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - return true; - } - backward = [&](ffStream_t stream) { - backward_kernel(stream, m, output_grad_acc, input_grad_accs, numInputs); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf( - "[Measure Concat] name(%s) forward_time(%.4lf) backward_time(%.4lf)\n", - name, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure Concat] name(%s) forward_time(%.4lf)\n", - name, - cost_metrics.forward_time); - } - - return true; + register_task(CONCAT_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); } + }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/concat.h b/lib/runtime/src/ops/concat.h index 0e0c0c2523..5f792566b4 100644 --- a/lib/runtime/src/ops/concat.h +++ b/lib/runtime/src/ops/concat.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_CONCAT_H #include "op-attrs/ops/concat.h" -#include "op_task_invocation.h" +#include "task_spec/op_task_invocation.h" #include "sim_environment.h" namespace FlexFlow { @@ -76,3 +76,336 @@ CostMetrics } // namespace FlexFlow #endif + + +// bool operator==(ConcatParams const &lhs, ConcatParams const &rhs) { +// return lhs.axis == rhs.axis; +// } + +// ConcatParams Concat::get_params() const { +// ConcatParams params; +// params.axis = legion_axis; +// return params; +// } + +// Tensor +// FFModel::concat(int n, Tensor const *tensors, int axis, char const *name) { +// Layer *concat = new Layer(this, +// 
OP_CONCAT, +// DT_FLOAT, +// name, +// n /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// tensors); +// int numdim = tensors[0]->num_dims; +// // Making sure axis is between [0, numdim) +// axis = (axis % numdim + numdim) % numdim; +// int dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdim; i++) { +// dims[i] = tensors[0]->dims[i]; +// } +// for (int i = 1; i < n; i++) { +// assert(tensors[i]->data_type == tensors[0]->data_type); +// assert(tensors[i]->num_dims == tensors[0]->num_dims); +// for (int j = 0; j < numdim; j++) { +// if (j != numdim - axis - 1) { +// assert(tensors[i]->dims[j] == tensors[0]->dims[j]); +// } else { +// dims[j] += tensors[i]->dims[j]; +// } +// } +// } +// concat->outputs[0] = create_tensor_legion_ordering( +// numdim, dims, tensors[0]->data_type, concat, 0, true /*create_grad*/); +// concat->add_int_property("legion_axis", numdim - axis - 1); +// layers.push_back(concat); +// return concat->outputs[0]; +// } + +// Op *Concat::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("legion_axis", value); +// int legion_axis = value; +// return new Concat( +// model, inputs.size(), inputs.data(), legion_axis, layer->name); +// } + +// Concat::Concat(FFModel &model, +// int _n, +// ParallelTensor const *_tensors, +// int _legion_axis, +// char const *name) +// : Op(model, +// OP_CONCAT, +// DT_FLOAT, +// name, +// _n /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// _tensors), +// legion_axis(_legion_axis) { +// int num_dim = inputs[0]->num_dims; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < num_dim; i++) { +// dims[i] = inputs[0]->dims[i]; +// } +// for (int i = 1; i < numInputs; i++) { +// assert(inputs[i]->data_type == inputs[0]->data_type); +// assert(inputs[i]->num_dims == inputs[0]->num_dims); +// for (int j = 0; j < num_dim; j++) { +// if (j != legion_axis) { +// assert(inputs[i]->dims[j] == 
inputs[0]->dims[j]); +// } else { +// // Assert that the concat dim cannot be parallelized +// assert(inputs[i]->dims[j].parallel_idx == -1); +// assert(inputs[i]->dims[j].degree == 1); +// dims[j].size += inputs[i]->dims[j].size; +// } +// } +// } +// numOutputs = 1; +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// num_dim, dims, inputs[0]->data_type, this); +// } + +// Concat::Concat(FFModel &model, +// ConcatParams const ¶ms, +// std::vector const &inputs, +// char const *name) +// : Concat(model, inputs.size(), inputs.data(), params.axis, name) {} + + +// void Concat::init(FFModel const &ff) { +// this->execute_task(ff, CONCAT_INIT_TASK_ID, get_init_task_signature()); +// // assert(check_output_input_weight_same_parallel_is()); +// // parallel_is = outputs[0]->parallel_is; +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_init(ff, argmap); +// // IndexLauncher launcher(CONCAT_INIT_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(Concat)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region_grad)); +// // launcher.add_field(i + numInputs + 1, FID_DATA); +// // } +// // FutureMap fm = 
runtime->execute_index_space(ctx, launcher); +// // fm.wait_all_results(); +// // set_opmeta_from_futuremap(ff, fm); +// } + + +// void Concat::forward(FFModel const &ff) { +// this->execute_task(ff, CONCAT_FWD_TASK_ID, get_fwd_task_signature()); +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_forward(ff, argmap); +// // IndexLauncher launcher(CONCAT_FWD_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(Concat)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](O): output + regions[1..numInputs](I): inputs +*/ + + +// void Concat::backward(FFModel const &ff) { +// this->execute_task(ff, CONCAT_BWD_TASK_ID, get_bwd_task_signature()); +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_backward(ff, argmap); +// // IndexLauncher launcher(CONCAT_BWD_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(Concat)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region_grad)); +// // launcher.add_field(0, FID_DATA); 
+// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, +// // 0 /*projection id*/, +// // READ_WRITE, +// // EXCLUSIVE, +// // inputs[i]->region_grad)); +// // // LogicalRegion lr = inputs[i]->region_grad; +// // // printf("concat[%d]: region(%d,%d,%d)\n", i+1, +// // // lr.get_index_space().get_id(), lr.get_field_space().get_id(), +// // // lr.get_tree_id()); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): output_grad + regions[1..numInputs](I/O): input_grad +*/ + +// bool Concat::get_int_parameter(PMParameter para, int *value) const { +// switch (para) { +// case PM_AXIS: +// *value = legion_axis; +// return true; +// default: +// return Op::get_int_parameter(para, value); +// } +// } + + +// bool Concat::measure_operator_cost(Simulator *sim, +// MachineView const &mv, +// CostMetrics &cost_metrics) const { +// assert(numInputs <= MAX_NUM_INPUTS); +// ParallelTensorBase sub_inputs[MAX_NUM_INPUTS], sub_output; +// if (!outputs[0]->get_sub_tensor(mv, sub_output)) { +// return false; +// } +// for (int i = 0; i < numInputs; i++) { +// if (!inputs[i]->get_sub_tensor(mv, sub_inputs[i])) { +// return false; +// } +// } + +// ConcatPerDeviceState *m = sim->concat_meta; +// init_meta(m, this->legion_axis); + +// sim->free_all(); +// float *input_ptrs[MAX_NUM_INPUTS]; +// float *input_grad_ptrs[MAX_NUM_INPUTS]; +// bool out_of_memory = false; +// for (int i = 0; i < numInputs; i++) { +// input_ptrs[i] = +// (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); +// out_of_memory = out_of_memory || (input_ptrs[i] == NULL); +// } +// cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + +// Domain out_domain = sub_output.get_domain(); +// float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); +// GenericTensorAccessorW output_acc(DT_FLOAT, out_domain, output_ptr); 
+// cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + +// out_of_memory = out_of_memory || (output_ptr == NULL); +// if (out_of_memory) { +// cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// return true; +// } + +// Domain in_domains[MAX_NUM_INPUTS]; +// GenericTensorAccessorR input_acc[MAX_NUM_INPUTS]; +// for (int i = 0; i < numInputs; i++) { +// in_domains[i] = sub_inputs[i].get_domain(); +// input_acc[i] = +// GenericTensorAccessorR(DT_FLOAT, in_domains[i], input_ptrs[i]); +// } + +// assert(m->profiling == false); + +// std::function forward, backward; +// forward = [&](ffStream_t stream) { +// forward_kernel(stream, m, output_acc, input_acc, numInputs); +// }; +// if (sim->computationMode == COMP_MODE_TRAINING) { +// GenericTensorAccessorW input_grad_accs[MAX_NUM_INPUTS]; +// for (int i = 0; i < numInputs; i++) { +// input_grad_ptrs[i] = +// (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); +// out_of_memory = out_of_memory || (input_grad_ptrs[i] == NULL); +// input_grad_accs[i] = +// GenericTensorAccessorW(DT_FLOAT, in_domains[i], input_grad_ptrs[i]); +// } +// cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); +// float *output_grad_ptr = +// (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); +// GenericTensorAccessorR output_grad_acc( +// DT_FLOAT, out_domain, output_grad_ptr); +// cost_metrics.outputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); + +// out_of_memory = out_of_memory || (output_grad_ptr == NULL); +// if (out_of_memory) { +// cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// return true; +// } +// backward = [&](ffStream_t stream) { +// backward_kernel(stream, m, output_grad_acc, input_grad_accs, numInputs); +// }; +// } + +// inner_measure_operator_cost(sim, forward, backward, 
cost_metrics); + +// if (sim->computationMode == COMP_MODE_TRAINING) { +// printf( +// "[Measure Concat] name(%s) forward_time(%.4lf) backward_time(%.4lf)\n", +// name, +// cost_metrics.forward_time, +// cost_metrics.backward_time); +// } else { +// printf("[Measure Concat] name(%s) forward_time(%.4lf)\n", +// name, +// cost_metrics.forward_time); +// } + +// return true; +// } \ No newline at end of file diff --git a/lib/runtime/src/serialization.h b/lib/runtime/src/serialization.h index adf838201a..6839bfdde5 100644 --- a/lib/runtime/src/serialization.h +++ b/lib/runtime/src/serialization.h @@ -9,7 +9,8 @@ #include "utils/optional.h" #include "utils/variant.h" #include "utils/visitable.h" -#include +#include "utils/type_traits.h" +#include "utils/required.h" namespace FlexFlow { @@ -77,6 +78,12 @@ struct is_trivially_serializable< typename std::enable_if::value>::type> : std::true_type {}; +template +struct is_trivially_serializable>> : is_trivially_serializable> {}; + +template +struct is_trivially_serializable> : is_trivially_serializable {}; + template <> struct is_trivially_serializable : std::true_type {}; template <> diff --git a/lib/runtime/src/task_spec/op_task_invocation.h b/lib/runtime/src/task_spec/op_task_invocation.h index 07f5bf12ae..9203f2f4f1 100644 --- a/lib/runtime/src/task_spec/op_task_invocation.h +++ b/lib/runtime/src/task_spec/op_task_invocation.h @@ -5,6 +5,8 @@ #include "index_task_invocation.h" #include "legion.h" #include "op_arg_ref.h" +#include "op_tensor_spec.h" +#include "variadic_tensor_ref.h" #include "op_task_signature.h" #include "runtime/config.h" #include "runtime/profiling.h" @@ -22,16 +24,6 @@ namespace FlexFlow { enum class IsTrainable { YES, NO }; -struct OpTensorSpec { - TensorRole role; - req idx; -}; -FF_VISITABLE_STRUCT(OpTensorSpec, role, idx); - -OpTensorSpec input_tensor(int); -OpTensorSpec output_tensor(int); -OpTensorSpec weight_tensor(int); - using OpArgSpec = variant + void bind(slot_id name, VariadicTensorRef 
const &t) { + NOT_IMPLEMENTED(); + } + template void bind_device_specific_arg(slot_id name, T const &t) { NOT_IMPLEMENTED(); diff --git a/lib/runtime/src/task_spec/op_tensor_spec.h b/lib/runtime/src/task_spec/op_tensor_spec.h new file mode 100644 index 0000000000..84261141d3 --- /dev/null +++ b/lib/runtime/src/task_spec/op_tensor_spec.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_TENSOR_SPEC_REF_H +#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_TENSOR_SPEC_REF_H + +#include "op_task_signature.h" + +namespace FlexFlow { + +struct OpTensorSpec { + TensorRole role; + req idx; +}; +FF_VISITABLE_STRUCT(OpTensorSpec, role, idx); + +OpTensorSpec input_tensor(int); +OpTensorSpec output_tensor(int); +OpTensorSpec weight_tensor(int); + + +} // namespace FlexFlow + +#endif diff --git a/lib/runtime/src/task_spec/variadic_tensor_ref.h b/lib/runtime/src/task_spec/variadic_tensor_ref.h new file mode 100644 index 0000000000..1b9bc33b4f --- /dev/null +++ b/lib/runtime/src/task_spec/variadic_tensor_ref.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H +#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H + +#include "op_tensor_spec.h" +#include "arg_ref.h" + +namespace FlexFlow { + +enum class VariadicTensorRefType {INPUT_TENSORS, + NUM_INPUTS}; + +template +using VariadicTensorRef = ArgRef; + +VariadicTensorRef get_input_tensors() { + return {VariadicTensorRefType::INPUT_TENSORS}; +} + +VariadicTensorRef get_number_inputs() { + return {VariadicTensorRefType::NUM_INPUTS}; +} + + +} // namespace FlexFlow + +#endif From bf750675ebabc1abbf3bad7620455f80a74d4b1a Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 1 Sep 2023 11:16:07 -0700 Subject: [PATCH 23/24] format --- lib/kernels/include/kernels/combine_kernels.h | 3 +-- lib/runtime/src/ops/combine.cc | 15 +++++---------- lib/runtime/src/ops/combine.h | 18 ++++++++---------- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git 
a/lib/kernels/include/kernels/combine_kernels.h b/lib/kernels/include/kernels/combine_kernels.h index 44ab67d9a7..24b9adb803 100644 --- a/lib/kernels/include/kernels/combine_kernels.h +++ b/lib/kernels/include/kernels/combine_kernels.h @@ -10,8 +10,7 @@ struct CombinePerDeviceState { req data_type; }; -FF_VISITABLE_STRUCT_NO_EQ(CombinePerDeviceState, - data_type); +FF_VISITABLE_STRUCT_NO_EQ(CombinePerDeviceState, data_type); namespace Kernels { namespace Combine { diff --git a/lib/runtime/src/ops/combine.cc b/lib/runtime/src/ops/combine.cc index c2d8d9a017..a22b317d10 100644 --- a/lib/runtime/src/ops/combine.cc +++ b/lib/runtime/src/ops/combine.cc @@ -26,12 +26,7 @@ using Legion::Task; using namespace FlexFlow::Kernels::Combine; -enum Slots { - INPUT, - OUTPUT, - PROFILING, - PER_DEVICE_STATE -}; +enum Slots { INPUT, OUTPUT, PROFILING, PER_DEVICE_STATE }; OpTaskInvocation init(CombineAttrs const &attrs) { OpTaskBinding binding; @@ -44,7 +39,8 @@ OpTaskInvocation init(CombineAttrs const &attrs) { OpTaskInvocation forward(CombineAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); binding.bind_arg(PROFILING, profiling_settings()); binding.bind(INPUT, input_tensor(0)); @@ -61,7 +57,7 @@ OpTaskInvocation backward(CombineAttrs const &attrs) { static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { - + auto input = acc.get_tensor(INPUT); DeviceSpecific per_device_state = @@ -171,5 +167,4 @@ void register_task() { }; // namespace FlexFlow -namespace std { -}; // namespace std +namespace std {}; // namespace std diff --git a/lib/runtime/src/ops/combine.h b/lib/runtime/src/ops/combine.h index 4b96ce6495..512c3be363 100644 --- a/lib/runtime/src/ops/combine.h +++ b/lib/runtime/src/ops/combine.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_COMBINE_H #include "op-attrs/ops/combine.h" -#include "task_spec/op_task_invocation.h" #include "sim_environment.h" 
+#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -112,7 +112,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // output_grad_ptr, input_grad_ptr, output_grad_domain.get_volume()); // } - // void Combine::backward_task(Task const *task, // std::vector const ®ions, // Context ctx, @@ -171,7 +170,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // return rf; // } - // void Combine::init(FFModel const &ff) { // parallel_is = outputs[0]->parallel_is; // ArgumentMap argmap; @@ -188,7 +186,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // 0 /*mapper_id*/, // outputs[0]->machine_view.hash()); // launcher.add_region_requirement(RegionRequirement( -// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); +// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, +// inputs[0]->region)); // launcher.add_field(0, FID_DATA); // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, // 0 /*projection id*/, @@ -232,7 +231,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // 0 /*mapper_id*/, // outputs[0]->machine_view.hash()); // launcher.add_region_requirement(RegionRequirement( -// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); +// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, +// inputs[0]->region)); // launcher.add_field(0, FID_DATA); // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, // 0 /*projection id*/, @@ -274,8 +274,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // runtime->execute_index_space(ctx, launcher); // } - - // CombineParams Combine::get_params() const { // CombineParams params; // params.combine_legion_dim = this->combine_dim; @@ -338,9 +336,9 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // template // void Combine::forward_task_with_type(Task const *task, -// std::vector const ®ions, -// Context ctx, -// 
Runtime *runtime) { +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { // Domain input_domain = runtime->get_index_space_domain( // ctx, task->regions[0].region.get_index_space()); // Domain output_domain = runtime->get_index_space_domain( From d1f4fb91729c7bf9cfe612c46cf33707c0ae656a Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 1 Sep 2023 11:17:05 -0700 Subject: [PATCH 24/24] format --- lib/kernels/include/kernels/concat_kernels.h | 11 ++-- lib/runtime/src/ops/concat.cc | 58 ++++++++++--------- lib/runtime/src/ops/concat.h | 47 ++++++++------- lib/runtime/src/serialization.h | 7 ++- .../src/task_spec/op_task_invocation.h | 4 +- lib/runtime/src/task_spec/op_tensor_spec.h | 1 - .../src/task_spec/variadic_tensor_ref.h | 6 +- 7 files changed, 67 insertions(+), 67 deletions(-) diff --git a/lib/kernels/include/kernels/concat_kernels.h b/lib/kernels/include/kernels/concat_kernels.h index 1bbefdf843..165f63f332 100644 --- a/lib/kernels/include/kernels/concat_kernels.h +++ b/lib/kernels/include/kernels/concat_kernels.h @@ -23,11 +23,12 @@ void forward_kernel(ffStream_t stream, std::vector const &inputs, int num_inputs); -void backward_kernel(ffStream_t stream, - ConcatPerDeviceState const *m, - GenericTensorAccessorR const &output_grad, - std::vector const &input_grads, - int num_inputs); +void backward_kernel( + ffStream_t stream, + ConcatPerDeviceState const *m, + GenericTensorAccessorR const &output_grad, + std::vector const &input_grads, + int num_inputs); } // namespace Concat } // namespace Kernels diff --git a/lib/runtime/src/ops/concat.cc b/lib/runtime/src/ops/concat.cc index 20c54c993e..d8f610c3ea 100644 --- a/lib/runtime/src/ops/concat.cc +++ b/lib/runtime/src/ops/concat.cc @@ -16,9 +16,9 @@ #include "concat.h" #include "kernels/concat_kernels.h" #include "legion/legion_utilities.h" +#include "op-attrs/get_output_shapes.h" #include "task_spec/variadic_tensor_ref.h" #include "utils/hash-utils.h" -#include "op-attrs/get_output_shapes.h" 
namespace FlexFlow { @@ -50,7 +50,8 @@ OpTaskInvocation init(ConcatAttrs const &attrs) { OpTaskInvocation forward(ConcatAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); binding.bind(INPUTS, get_input_tensors()); binding.bind(OUTPUT, output_tensor(0)); binding.bind(NUM_INPUTS, get_number_inputs()); @@ -70,7 +71,7 @@ static DeviceSpecific auto const &attrs = acc.get_argument(ATTRS); PerDeviceFFHandle handle = acc.get_argument(HANDLE); - DeviceSpecific per_device_state = + DeviceSpecific per_device_state = acc.create_device_specific(init_kernel(attrs.axis)); return per_device_state; } @@ -85,7 +86,8 @@ static DeviceSpecific } static optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); int number_of_inputs = acc.get_argument(NUM_INPUTS); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -93,12 +95,12 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { auto inputs = acc.get_variadic_tensor(INPUTS); return profile(forward_kernel, - profiling, - "[Concat] forward_time = %.2lfms\n", - &per_device_state, - output, - inputs, - number_of_inputs); + profiling, + "[Concat] forward_time = %.2lfms\n", + &per_device_state, + output, + inputs, + number_of_inputs); } static void forward_task(Task const *task, @@ -110,7 +112,8 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); int number_of_inputs = acc.get_argument(NUM_INPUTS); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -120,12 +123,12 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(number_of_inputs <= 
MAX_NUM_INPUTS); return profile(backward_kernel, - profiling, - "[Concat] backward_time = %.2lfms\n", - &per_device_state, - output_grad, - input_grads, - number_of_inputs); + profiling, + "[Concat] backward_time = %.2lfms\n", + &per_device_state, + output_grad, + input_grads, + number_of_inputs); } static void backward_task(Task const *task, @@ -136,24 +139,25 @@ static void backward_task(Task const *task, backward_task_impl(acc); } -CostMetrics measure_operator_cost(SimEnvFactory const &sim, - ConcatAttrs const &attrs, - InputVariadicParallelTensorDesc const &inputs_shape, - ProfilingSettings const &settings, - MachineView const &mv) { +CostMetrics + measure_operator_cost(SimEnvFactory const &sim, + ConcatAttrs const &attrs, + InputVariadicParallelTensorDesc const &inputs_shape, + ProfilingSettings const &settings, + MachineView const &mv) { int numInputs = (inputs_shape.shapes).size(); assert(numInputs <= MAX_NUM_INPUTS); auto env = sim.new_environment(); - ParallelTensorShape output_shape = get_output_shape(attrs, inputs_shape.shapes); + ParallelTensorShape output_shape = + get_output_shape(attrs, inputs_shape.shapes); SimTaskBinding init_binding; init_binding.bind_arg(PROFILING, settings); init_binding.bind_arg(ATTRS, attrs); - auto init_accessor = - env.get_init_accessor(CONCAT_INIT_TASK_ID, init_binding); + auto init_accessor = env.get_init_accessor(CONCAT_INIT_TASK_ID, init_binding); DeviceSpecific per_device_state = init_task_impl(init_accessor); @@ -174,7 +178,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, float sync_time = default_estimate_sync_time(env); return make_metrics(forward_time, backward_time, sync_time, env); - } template <> @@ -183,7 +186,7 @@ void register_task() { init.add_arg_slot(ATTRS); init.add_arg_slot(PROFILING); - + register_task(CONCAT_INIT_TASK_ID, "Concat Init", init, init_task); } @@ -208,5 +211,4 @@ void register_task() { register_task(CONCAT_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); } - }; // 
namespace FlexFlow diff --git a/lib/runtime/src/ops/concat.h b/lib/runtime/src/ops/concat.h index 5f792566b4..0493006345 100644 --- a/lib/runtime/src/ops/concat.h +++ b/lib/runtime/src/ops/concat.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_CONCAT_H #include "op-attrs/ops/concat.h" -#include "task_spec/op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -77,7 +77,6 @@ CostMetrics #endif - // bool operator==(ConcatParams const &lhs, ConcatParams const &rhs) { // return lhs.axis == rhs.axis; // } @@ -89,7 +88,8 @@ CostMetrics // } // Tensor -// FFModel::concat(int n, Tensor const *tensors, int axis, char const *name) { +// FFModel::concat(int n, Tensor const *tensors, int axis, char const *name) +// { // Layer *concat = new Layer(this, // OP_CONCAT, // DT_FLOAT, @@ -178,7 +178,6 @@ CostMetrics // char const *name) // : Concat(model, inputs.size(), inputs.data(), params.axis, name) {} - // void Concat::init(FFModel const &ff) { // this->execute_task(ff, CONCAT_INIT_TASK_ID, get_init_task_signature()); // // assert(check_output_input_weight_same_parallel_is()); @@ -210,11 +209,11 @@ CostMetrics // // launcher.add_field(i + 1, FID_DATA); // // } // // for (int i = 0; i < numInputs; i++) { -// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, // // 0 /*projection id*/, // // WRITE_ONLY, // // EXCLUSIVE, -// // inputs[i]->region_grad)); +// // inputs[i]->region_grad)); // // launcher.add_field(i + numInputs + 1, FID_DATA); // // } // // FutureMap fm = runtime->execute_index_space(ctx, launcher); @@ -222,7 +221,6 @@ CostMetrics // // set_opmeta_from_futuremap(ff, fm); // } - // void Concat::forward(FFModel const &ff) { // this->execute_task(ff, CONCAT_FWD_TASK_ID, get_fwd_task_signature()); // // ArgumentMap argmap; @@ -259,7 +257,6 @@ CostMetrics regions[1..numInputs](I): inputs */ - // void 
Concat::backward(FFModel const &ff) { // this->execute_task(ff, CONCAT_BWD_TASK_ID, get_bwd_task_signature()); // // ArgumentMap argmap; @@ -278,14 +275,14 @@ CostMetrics // // 0 /*projection id*/, // // READ_ONLY, // // EXCLUSIVE, -// // outputs[0]->region_grad)); +// // outputs[0]->region_grad)); // // launcher.add_field(0, FID_DATA); // // for (int i = 0; i < numInputs; i++) { -// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, // // 0 /*projection id*/, // // READ_WRITE, // // EXCLUSIVE, -// // inputs[i]->region_grad)); +// // inputs[i]->region_grad)); // // // LogicalRegion lr = inputs[i]->region_grad; // // // printf("concat[%d]: region(%d,%d,%d)\n", i+1, // // // lr.get_index_space().get_id(), lr.get_field_space().get_id(), @@ -310,7 +307,6 @@ CostMetrics // } // } - // bool Concat::measure_operator_cost(Simulator *sim, // MachineView const &mv, // CostMetrics &cost_metrics) const { @@ -337,12 +333,14 @@ CostMetrics // (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); // out_of_memory = out_of_memory || (input_ptrs[i] == NULL); // } -// cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); +// cost_metrics.inputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); // Domain out_domain = sub_output.get_domain(); -// float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); -// GenericTensorAccessorW output_acc(DT_FLOAT, out_domain, output_ptr); -// cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); +// float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), +// DT_FLOAT); GenericTensorAccessorW output_acc(DT_FLOAT, out_domain, +// output_ptr); cost_metrics.outputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); // out_of_memory = out_of_memory || (output_ptr == NULL); // if (out_of_memory) { @@ -372,10 +370,11 @@ CostMetrics // (float 
*)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); // out_of_memory = out_of_memory || (input_grad_ptrs[i] == NULL); // input_grad_accs[i] = -// GenericTensorAccessorW(DT_FLOAT, in_domains[i], input_grad_ptrs[i]); +// GenericTensorAccessorW(DT_FLOAT, in_domains[i], +// input_grad_ptrs[i]); // } -// cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); -// float *output_grad_ptr = +// cost_metrics.inputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); float *output_grad_ptr = // (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); // GenericTensorAccessorR output_grad_acc( // DT_FLOAT, out_domain, output_grad_ptr); @@ -389,7 +388,8 @@ CostMetrics // return true; // } // backward = [&](ffStream_t stream) { -// backward_kernel(stream, m, output_grad_acc, input_grad_accs, numInputs); +// backward_kernel(stream, m, output_grad_acc, input_grad_accs, +// numInputs); // }; // } @@ -397,9 +397,8 @@ CostMetrics // if (sim->computationMode == COMP_MODE_TRAINING) { // printf( -// "[Measure Concat] name(%s) forward_time(%.4lf) backward_time(%.4lf)\n", -// name, -// cost_metrics.forward_time, +// "[Measure Concat] name(%s) forward_time(%.4lf) +// backward_time(%.4lf)\n", name, cost_metrics.forward_time, // cost_metrics.backward_time); // } else { // printf("[Measure Concat] name(%s) forward_time(%.4lf)\n", @@ -408,4 +407,4 @@ CostMetrics // } // return true; -// } \ No newline at end of file +// } diff --git a/lib/runtime/src/serialization.h b/lib/runtime/src/serialization.h index 6839bfdde5..5c1194c7d6 100644 --- a/lib/runtime/src/serialization.h +++ b/lib/runtime/src/serialization.h @@ -7,10 +7,10 @@ #include "legion/legion_utilities.h" #include "op-attrs/dim_ordered.h" #include "utils/optional.h" +#include "utils/required.h" +#include "utils/type_traits.h" #include "utils/variant.h" #include "utils/visitable.h" -#include "utils/type_traits.h" -#include "utils/required.h" namespace FlexFlow { @@ -79,7 +79,8 @@ struct 
is_trivially_serializable< : std::true_type {}; template -struct is_trivially_serializable>> : is_trivially_serializable> {}; +struct is_trivially_serializable>> + : is_trivially_serializable> {}; template struct is_trivially_serializable> : is_trivially_serializable {}; diff --git a/lib/runtime/src/task_spec/op_task_invocation.h b/lib/runtime/src/task_spec/op_task_invocation.h index 9203f2f4f1..56e709734e 100644 --- a/lib/runtime/src/task_spec/op_task_invocation.h +++ b/lib/runtime/src/task_spec/op_task_invocation.h @@ -5,9 +5,8 @@ #include "index_task_invocation.h" #include "legion.h" #include "op_arg_ref.h" -#include "op_tensor_spec.h" -#include "variadic_tensor_ref.h" #include "op_task_signature.h" +#include "op_tensor_spec.h" #include "runtime/config.h" #include "runtime/profiling.h" #include "serialization.h" @@ -16,6 +15,7 @@ #include "utils/bidict.h" #include "utils/optional.h" #include "utils/stack_map.h" +#include "variadic_tensor_ref.h" #include #include #include diff --git a/lib/runtime/src/task_spec/op_tensor_spec.h b/lib/runtime/src/task_spec/op_tensor_spec.h index 84261141d3..d859bb3072 100644 --- a/lib/runtime/src/task_spec/op_tensor_spec.h +++ b/lib/runtime/src/task_spec/op_tensor_spec.h @@ -15,7 +15,6 @@ OpTensorSpec input_tensor(int); OpTensorSpec output_tensor(int); OpTensorSpec weight_tensor(int); - } // namespace FlexFlow #endif diff --git a/lib/runtime/src/task_spec/variadic_tensor_ref.h b/lib/runtime/src/task_spec/variadic_tensor_ref.h index 1b9bc33b4f..ddd9bd5069 100644 --- a/lib/runtime/src/task_spec/variadic_tensor_ref.h +++ b/lib/runtime/src/task_spec/variadic_tensor_ref.h @@ -1,13 +1,12 @@ #ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H #define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H -#include "op_tensor_spec.h" #include "arg_ref.h" +#include "op_tensor_spec.h" namespace FlexFlow { -enum class VariadicTensorRefType {INPUT_TENSORS, - NUM_INPUTS}; +enum class VariadicTensorRefType { INPUT_TENSORS, 
NUM_INPUTS }; template using VariadicTensorRef = ArgRef; @@ -20,7 +19,6 @@ VariadicTensorRef get_number_inputs() { return {VariadicTensorRefType::NUM_INPUTS}; } - } // namespace FlexFlow #endif