From 6b248cf20e560dbc2ff57d0db2bf463775177620 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 22 Aug 2023 14:11:42 -0700 Subject: [PATCH 01/26] batch matmul initial commit --- .../include/kernels/batch_matmul_kernels.h | 23 +- lib/kernels/src/hip/batch_matmul_kernels.cpp | 7 +- .../include/op-attrs/ops/batch_matmul.h | 4 +- lib/op-attrs/src/batch_matmul.cc | 8 + lib/runtime/src/ops/batch_matmul.cc | 855 ++++-------------- lib/runtime/src/ops/batch_matmul.h | 441 ++++++++- 6 files changed, 655 insertions(+), 683 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 0e4437bdb8..cdffdf1907 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -6,17 +6,26 @@ namespace FlexFlow { -class BatchMatmulPerDeviceState : public PerDeviceOpState { -public: - BatchMatmulPerDeviceState(FFHandler handler); - int a_seq_length_dim, b_seq_length_dim; +struct BMMPerDeviceState { + PerDeviceFFHandle handle; + Allocator allocator; + int a_seq_length_dim; + int b_seq_length_dim; }; +FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, + handle,); + namespace Kernels { namespace BatchMatmul { +BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, + Allocator allocator, + int a_seq_length_dim, + int b_seq_length_dim); + void forward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, + BMMPerDeviceState const *meta, float *o_ptr, float const *a_ptr, float const *b_ptr, @@ -25,12 +34,10 @@ void forward_kernel(ffStream_t stream, int n, int k, int batch, - int a_seq_length_dim = -1, - int b_seq_length_dim = -1, int seq_length = -1); void backward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, + BMMPerDeviceState const *meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp b/lib/kernels/src/hip/batch_matmul_kernels.cpp index d8b6500326..e5334b1841 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -19,9 +19,6 @@ namespace FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -42,9 +39,9 @@ void forward_kernel(hipStream_t stream, int k, int batch, hipStream_t stream, - int a_seq_length_dim, - int b_seq_length_dim, int seq_length) { + int a_seq_length_dim = meta->a_seq_length_dim; + int b_seq_length_dim = meta->b_seq_length_dim; checkCUDA(hipblasSetStream(meta->handle.blas, stream)); checkCUDNN(miopenSetStream(meta->handle.dnn, stream)); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index c74824570c..00c700ba20 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -12,8 +12,10 @@ struct BatchMatmulAttrs { }; FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); -CHECK_VALID_OP_ATTR(BatchMatmulAttrs); +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); +CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/batch_matmul.cc b/lib/op-attrs/src/batch_matmul.cc index 1cc8c5cfda..bd61c24737 100644 --- a/lib/op-attrs/src/batch_matmul.cc +++ b/lib/op-attrs/src/batch_matmul.cc @@ -2,6 +2,14 @@ namespace FlexFlow { +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs) { + return attrs.a_seq_length_dim; +} + +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs) { + return attrs.b_seq_length_dim; +} + /* bool BatchMatmulAttrs::is_valid( */ /* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { */ diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 3e860bd413..9bbc050f8b 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -13,754 +13,287 @@ * limitations under the License. */ +// #include "batch_matmul.h" +// #include "kernels/batch_matmul_kernels.h" +// #include "kernels/profiling.h" +// #include "legion/legion_utilities.h" +// #include "tasks.h" + #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" -#include "kernels/profiling.h" -#include "legion/legion_utilities.h" -#include "tasks.h" +#include "legion.h" +#include "op-attrs/ops/batch_matmul.h" namespace FlexFlow { using namespace FlexFlow::Kernels::BatchMatmul; +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; + enum Slots { - A_INPUT, - B_INPUT, - OUTPUT, - A_INPUT_GRAD, - B_INPUT_GRAD, - OUTPUT_GRAD, - ATTRS, - PROFILING -}; + A_INPUT, //tensor + B_INPUT, //tensor + OUTPUT, //tensor + PROFILING, + HANDLE, + A_SEQ_LENGTH_DIM, + B_SEQ_LENGTH_DIM, + PER_DEVICE_STATE + }; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; + OpTaskBinding init; - b.bind_arg(ATTRS, attrs); - b.bind_arg(PROFILING, enable_profiling()); + init.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init.bind_arg(HANDLE, ff_handle()); - return {BATCHMATMUL_INIT_TASK_ID, b}; + return {BATCHMATMUL_INIT_TASK_ID, init}; } OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; - - b.bind(A_INPUT, input_tensor(0)); - b.bind(B_INPUT, input_tensor(1)); - b.bind(OUTPUT, output_tensor(0)); - - return {BATCHMATMUL_FWD_TASK_ID, b}; -} - -OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return {BATCHMATMUL_BWD_TASK_ID, b}; -} - -BatchMatmulParams BatchMatmul::get_params() const { - BatchMatmulParams params; - params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; - params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; - return params; -} - -Tensor FFModel::batch_matmul(const Tensor A, - const Tensor B, - int a_seq_length_dim, - int b_seq_length_dim, - char const *name) { - Layer *bmm = new Layer(this, - OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B); - assert((a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - int dims[MAX_TENSOR_DIM]; - int numdim = A->num_dims; - for (int i = 0; i < numdim; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - bmm->outputs[0] = create_tensor_legion_ordering( - numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); - bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); - bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); - layers.push_back(bmm); - return bmm->outputs[0]; -} - -Op *BatchMatmul::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("a_seq_length_dim", value); - int a_seq_length_dim = value; - layer->get_int_property("b_seq_length_dim", value); - int b_seq_length_dim = value; - return new BatchMatmul(model, - inputs[0], - inputs[1], - a_seq_length_dim, - b_seq_length_dim, - layer->name); -} - -BatchMatmul::BatchMatmul( - FFModel &model, - BatchMatmulParams const ¶ms, - std::pair const &inputs, - char const *name) - : BatchMatmul(model, - inputs.first, - inputs.second, - params.a_seq_length_dim, - params.b_seq_length_dim, - name) {} - -// return A*B -BatchMatmul::BatchMatmul(FFModel &model, - const ParallelTensor A, - const ParallelTensor B, - int _a_seq_length_dim, - int _b_seq_length_dim, - char const *name) - : Op(model, - OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B), - a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), - b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { - assert((_a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((_b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < A->num_dims; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - numOutputs = 1; - outputs[0] = model.create_parallel_tensor_legion_ordering( - A->num_dims, dims, DT_FLOAT, this); - // C is not none - // if (C != Tensor::NO_TENSOR) { - // numInputs = 3; - // assert(C.num_dims == outputs[0].num_dims); - // for (int i = 0; i < C.num_dims; i++) - // assert(C.adim[i] == outputs[0].adim[i]); - //} -} - -void BatchMatmul::serialize(Legion::Serializer &sez) const { - BatchMatmulParams params = get_params(); - sez.serialize(params.a_seq_length_dim); - sez.serialize(params.b_seq_length_dim); -} - -using PCG::Node; -/*static*/ -Node BatchMatmul::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 2); - int a_seq_length_dim, b_seq_length_dim; - dez.deserialize(a_seq_length_dim); - dez.deserialize(b_seq_length_dim); - - BatchMatmulParams params; - params.a_seq_length_dim = a_seq_length_dim; - params.b_seq_length_dim = b_seq_length_dim; - return ff.get_or_create_node({inputs[0], inputs[1]}, params); -} + OpTaskBinding fwd; -Op *BatchMatmul::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - BatchMatmulParams params = get_params(); - return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); -} - -template <> -void register_task() { - OpTaskSignature sig(OpTaskType::INIT); + fwd.bind(A_INPUT, input_tensor(0)); + fwd.bind(B_INPUT, input_tensor(1)); + fwd.bind(OUTPUT, output_tensor(0)); - sig.add_arg_slot(ATTRS); - sig.add_arg_slot(PROFILING); + fwd.bind_arg(PROFILING, profiling_settings()); + fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", sig, init_task); + return {BATCHMATMUL_FWD_TASK_ID, fwd}; } -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(A_INPUT, READ_WRITE); - fwd.add_input_slot(B_INPUT, READ_WRITE); - fwd.add_output_slot(OUTPUT); - - return fwd; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_input_slot(A_INPUT); - bwd.add_input_slot(B_INPUT); - bwd.add_input_grad_slot(A_INPUT_GRAD); - bwd.add_input_grad_slot(B_INPUT_GRAD); - bwd.add_output_slot(OUTPUT); - bwd.add_output_grad_slot(OUTPUT_GRAD); - - return bwd; -} - -OpTaskBinding BatchMatmul::get_init_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); +OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { + OpTaskBinding bwd = infer_bwd_binding(forward(attrs).binding); - return binding; + return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -OpTaskBinding BatchMatmul::get_fwd_task_binding() const { - OpTaskBinding binding; - - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind(OUTPUT, output_tensor(0)); - - binding.bind_arg(ATTRS, this->attrs); - return binding; +static DeviceSpecificArg init_task_impl(TaskArgumentAccessor const &acc) { + auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); + auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + Allocator allocator = acc.get_allocator(); + + DeviceSpecificArg per_device_state = + acc.create_device_specific( + init_kernel(handle, + allocator, + a_seq_length_dim, + b_seq_length_dim)); + + // assert(weight.shape.get_volume() * sizeof(float) == + // acc.unwrap(per_device_state)->weightSize); + return per_device_state; } -OpTaskBinding BatchMatmul::get_bwd_task_binding() const { - OpTaskBinding binding; - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); - binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); - - binding.bind(OUTPUT, output_tensor(0)); - binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); - - binding.bind_arg(ATTRS, this->attrs); - return binding; -} - -void BatchMatmul::init(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // init_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} -} // namespace FlexFlow -// / -// template -// void BatchMatmul::init_with_dim(FFModel const &ff) { -// assert(check_output_input_weight_same_parallel_is()); -// parallel_is = outputs[0]->parallel_is; -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_init(ff, argmap); -// IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, -// parallel_is, -// TaskArgument(this, sizeof(BatchMatmul)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// FutureMap fm = runtime->execute_index_space(ctx, launcher); -// fm.wait_all_results(); -// set_opmeta_from_futuremap(ff, fm); -// } - -PerDeviceOpState * - BatchMatmul::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static DeviceSpecificArg + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - - FFHandler handle = *((FFHandler const *)task->local_args); - BatchMatmulPerDeviceState *m = new BatchMatmulPerDeviceState(handle); - m->profiling = profiling; - m->a_seq_length_dim = attrs.a_seq_length_dim; - m->b_seq_length_dim = attrs.b_seq_length_dim; - return m; -} - -void BatchMatmul::forward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // forward_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} + return init_task_impl(acc); } -// template -// void BatchMatmul::forward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_forward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_FWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](O): output - regions[1](I): A - regions[2](I): B - ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? - output = A * B /////////+ C -*/ -void BatchMatmul::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(regions.size() == 3); assert(task->regions.size() == 3); - TaskArgumentAccessor acc(task, regions, ctx, runtime); + auto a_input = acc.get_tensor(A_INPUT); + auto b_input = acc.get_tensor(B_INPUT); + auto output = acc.get_tensor(OUTPUT); - // const BatchMatmul* bmm = (const BatchMatmul*) task->args; + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - // BatchMatmulMeta const *meta = *((BatchMatmulMeta **)task->local_args); - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - auto a_input = acc.get_tensor(A_INPUT); - auto b_input = acc.get_tensor(B_INPUT); - auto output = acc.get_tensor(OUTPUT); - - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); + int batch = 1; for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); batch *= dim_size; } - float *out_ptr = output.get_float_ptr(); - c float const *a_ptr = a_input.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float const *c_ptr = NULL; - // if (regions.size() == 4) { - // Domain c_domain = runtime->get_index_space_domain( - // ctx, task->regions[3].region.get_index_space()); - // assert(c_domain == a_domain); - // c_ptr = helperGetTensorPointerRO( - // regions[3], task->regions[3], FID_DATA, ctx, runtime); - // } - - profile(forward_kernel, - meta->profiling, + + return profile(forward_kernel, + profiling, "[BatchMatmul] forward_time = %.2lfms\n", - out_ptr, - a_ptr, - b_ptr, - c_ptr, + per_device_state, + output.get_float_ptr(), + a_input.get_float_ptr(), + b_input.get_float_ptr(), + NULL, //c_ptr m, n, k, batch, - meta->a_seq_length_dim, - meta->b_seq_length_dim, iter_config->seq_length); } -void BatchMatmul::backward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - backward_with_dim(ff); \ - break; \ - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); - } +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -// template -// void BatchMatmul::backward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_backward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_BWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// // regions[0](I): output -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// // regions[1](I): output_grad -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region_grad)); -// launcher.add_field(1, FID_DATA); -// // regions[2](I): A -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[0]->region)); -// launcher.add_field(2, FID_DATA); -// // regions[3](I/O): A_grad -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[0]->region_grad)); -// launcher.add_field(3, FID_DATA); -// // regions[4](I): B -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[1]->region)); -// launcher.add_field(4, FID_DATA); -// // regions[5](I/O): B_grad -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[1]->region_grad)); -// launcher.add_field(5, FID_DATA); -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -__host__ void - BatchMatmul::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static optional backward_task_impl(TaskArgumentAccessor const &acc) { // Currently assume C is NULL assert(regions.size() == 6); assert(task->regions.size() == 6); + // BatchMatmul* bmm = (BatchMatmul*) task->args; - TaskArgumentAccessor acc(task, regions, ctx, runtime); FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - // output domains - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor(OUTPUT_GRAD); - assert(output == - output_grad); // is this equivalent to checking `Domain` equality? - // A domains - auto a_input = acc.get_tensor(A_INPUT); - auto a_input_grad = acc.get_tensor(A_INPUT_GRAD); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + // is this equivalent to checking `Domain` equality? + assert(output == output_grad); + + auto a_input = acc.get_tensor(A_INPUT); + auto a_input_grad = acc.get_tensor_grad(A_INPUT); assert(a_input == a_input_grad); - // B domains - auto b_input = acc.get_tensor(B_INPUT); - auto b_input_grad = acc.get_tensor(B_INPUT_GRAD); + + auto b_input = acc.get_tensor(B_INPUT); + auto b_input_grad = acc.get_tensor_grad(B_INPUT); assert(b_input == b_input_grad); // check dins - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); batch *= dim_size; } - // get pointers - float const *out_ptr = output.get_float_ptr(); - float const *out_grad_ptr = output_grad.get_float_ptr(); - float const *a_ptr = a_input.get_float_ptr(); - float *a_grad_ptr = a_input_grad.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float *b_grad_ptr = b_input_grad.get_float_ptr(); - - float *c_grad_ptr = NULL; // TODO: add support for meta->a_seq_length_dim >= 0 // or meta->b_seq_length_dim >= 0 assert((meta->a_seq_length_dim >= a_len) || (iter_config->seq_length == 0)); assert((meta->b_seq_length_dim >= b_len) || (iter_config->seq_length == 0)); - profile(backward_kernel, - meta->profiling, + return profile(backward_kernel, + profiling, "[BatchMatmul] backward_time = %.2lfms\n", - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + a_input.get_float_ptr(), + a_input_grad.get_float_ptr(), + b_input.get_float_ptr(), + b_input_grad.get_float_ptr(), + NULL, //c_grad_ptr m, n, k, batch); } -void BatchMatmul::print_layer(FFModel const &ff) { - return; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -bool BatchMatmul::measure_operator_cost(Simulator *sim, - MachineView const &pc, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input0, sub_input1; - if (!outputs[0]->get_sub_tensor(pc, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(pc, sub_input0)) { - return false; - } - if (!inputs[1]->get_sub_tensor(pc, sub_input1)) { - return false; - } - - int input0_c = sub_input0.dims[0].size; - int input0_r = sub_input0.dims[1].size; - int input1_c = sub_input1.dims[0].size; - int input1_r = sub_input1.dims[1].size; - int output_c = sub_output.dims[0].size; - int output_r = sub_output.dims[1].size; - - assert(input0_c == input1_r); - assert(input0_r == output_r); - assert(input1_c == output_c); - - assert(sub_input0.dims[2] == sub_input1.dims[2]); - assert(sub_input1.dims[2] == sub_output.dims[2]); - int batch = 1; - assert(sub_input0.num_dims == sub_input1.num_dims); - for (int i = 2; i < sub_input0.num_dims; i++) { - assert(sub_input0.dims[i] == sub_input1.dims[i]); - assert(sub_input0.dims[i] == sub_output.dims[i]); - batch *= sub_input0.dims[i].size; - } - - BatchMatmulPerDeviceState *meta = sim->batch_matmul_meta; +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc) const { + auto env = sim.new_environment(); + + //todo add get_output_shape and get_weights_shape to batch_matmul op-attrs + // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); + // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); + + SimTaskBinding init_binding; + init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init_binding.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init_binding.bind_arg(HANDLE, ff_handle()); + + auto init_accessor = + env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, init_binding); + DeviceSpecificArg per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(A_INPUT, a_input); + fwd_binding.bind(B_INPUT, b_input); + // fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); +} - // allocate tensors in simulator - sim->free_all(); - float *a_ptr = (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - assert(a_ptr != NULL); - float *b_ptr = (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - assert(b_ptr != NULL); - float *c_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); - float *out_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + init.add_arg_slot(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init.add_arg_slot(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init.add_unchecked_arg_slot(HANDLE, ff_handle()); - int m = input1_c; - int n = input0_r; - int k = input0_c; + register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); +} - assert(meta->profiling == false); +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, meta, out_ptr, a_ptr, b_ptr, c_ptr, m, n, k, batch); - }; + fwd.add_input_slot(A_INPUT); + fwd.add_input_slot(B_INPUT); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - if (sim->computationMode == COMP_MODE_TRAINING) { - float *a_grad_ptr = - (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - float *b_grad_ptr = - (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - float *c_grad_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *out_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, - m, - n, - k, - batch); - }; - } + register_task(BATCHMATMUL_FWD_TASK_ID, "BatchMatmul Fwd", fwd, forward_task); +} - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf) backward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time); - } +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(ATTENTION_FWD_TASK_ID)); - return true; + register_task(BATCHMATMUL_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); } -} -; // namespace FlexFlow + +}; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index c133c2a875..ab9bb45e8d 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -1,10 +1,14 @@ #ifndef _FLEXFLOW_BATCH_MATMUL_H #define _FLEXFLOW_BATCH_MATMUL_H +// #include "op-attrs/ops/batch_matmul.h" +// #include "task_spec/op_task_invocation.h" +// #include "task_spec/op_task_signature.h" +// #include "sim_environment.h" + #include "op-attrs/ops/batch_matmul.h" -#include "op_task_invocation.h" -#include "op_task_signature.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -19,12 +23,12 @@ OpTaskInvocation init(BatchMatmulAttrs const &); OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - BatchMatmulAttrs const &attrs, - ParallelTensorShape const &lhs_input_shape, - ParallelTensorShape const &rhs_input_shape, - ProfilingSettings const &settings, - MachineView const &); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc); /* class BatchMatmul : public Op { */ /* public: */ @@ -84,3 +88,424 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + + + +// BatchMatmulParams BatchMatmul::get_params() const { +// BatchMatmulParams params; +// params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; +// params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; +// return params; +// } + +// Tensor FFModel::batch_matmul(const Tensor A, +// const Tensor B, +// int a_seq_length_dim, +// int b_seq_length_dim, +// char const *name) { +// Layer *bmm = new Layer(this, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B); +// assert((a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// int dims[MAX_TENSOR_DIM]; +// int numdim = A->num_dims; +// for (int i = 0; i < numdim; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// bmm->outputs[0] = create_tensor_legion_ordering( +// numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); +// bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); +// bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); +// layers.push_back(bmm); +// return bmm->outputs[0]; +// } + +// Op *BatchMatmul::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("a_seq_length_dim", value); +// int a_seq_length_dim = value; +// layer->get_int_property("b_seq_length_dim", value); +// int b_seq_length_dim = value; +// return new BatchMatmul(model, +// inputs[0], +// inputs[1], +// a_seq_length_dim, +// b_seq_length_dim, +// layer->name); +// } + +// BatchMatmul::BatchMatmul( +// FFModel &model, +// BatchMatmulParams const ¶ms, +// std::pair const &inputs, +// char const *name) +// : BatchMatmul(model, +// inputs.first, +// inputs.second, +// params.a_seq_length_dim, +// params.b_seq_length_dim, +// name) {} + +// // return A*B +// BatchMatmul::BatchMatmul(FFModel &model, +// const ParallelTensor A, +// const ParallelTensor B, +// int _a_seq_length_dim, +// int _b_seq_length_dim, +// char const *name) +// : Op(model, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B), +// a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), +// b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { +// assert((_a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((_b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < A->num_dims; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// numOutputs = 1; +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// A->num_dims, dims, DT_FLOAT, this); +// // C is not none +// // if (C != Tensor::NO_TENSOR) { +// // numInputs = 3; +// // assert(C.num_dims == outputs[0].num_dims); +// // for (int i = 0; i < C.num_dims; i++) +// // assert(C.adim[i] == outputs[0].adim[i]); +// //} +// } + +// void BatchMatmul::serialize(Legion::Serializer &sez) const { +// BatchMatmulParams params = get_params(); +// sez.serialize(params.a_seq_length_dim); +// sez.serialize(params.b_seq_length_dim); +// } + +// using PCG::Node; +// /*static*/ +// Node BatchMatmul::deserialize(FFModel &ff, +// Legion::Deserializer &dez, +// ParallelTensor inputs[], +// int num_inputs) { +// assert(num_inputs == 2); +// int a_seq_length_dim, b_seq_length_dim; +// dez.deserialize(a_seq_length_dim); +// dez.deserialize(b_seq_length_dim); + +// BatchMatmulParams params; +// params.a_seq_length_dim = a_seq_length_dim; +// params.b_seq_length_dim = b_seq_length_dim; +// return ff.get_or_create_node({inputs[0], inputs[1]}, params); +// } + +// Op *BatchMatmul::materialize(FFModel &ff, +// ParallelTensor inputs[], +// int num_inputs) const { +// BatchMatmulParams params = get_params(); +// return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); +// } + + +// void BatchMatmul::forward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // forward_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); +// break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + +// template +// void BatchMatmul::forward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_FWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// for (int i = 0; i < numInputs; i++) { +// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[i]->region)); +// launcher.add_field(i + 1, FID_DATA); +// } +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](O): output + regions[1](I): A + regions[2](I): B + ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? + output = A * B /////////+ C +*/ + + +// void BatchMatmul::init(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // init_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); +// break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } // namespace FlexFlow +// // / +// // template +// // void BatchMatmul::init_with_dim(FFModel const &ff) { +// // assert(check_output_input_weight_same_parallel_is()); +// // parallel_is = outputs[0]->parallel_is; +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_init(ff, argmap); +// // IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(BatchMatmul)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // FutureMap fm = runtime->execute_index_space(ctx, launcher); +// // fm.wait_all_results(); +// // set_opmeta_from_futuremap(ff, fm); +// // } + +// OpTaskBinding BatchMatmul::get_bwd_task_binding() const { +// OpTaskBinding binding; +// binding.bind(A_INPUT, input_tensor(0)); +// binding.bind(B_INPUT, input_tensor(1)); +// binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); +// binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); + +// binding.bind(OUTPUT, output_tensor(0)); +// binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + + +// static OpTaskSignature get_fwd_task_signature() { +// OpTaskSignature fwd(OpTaskType::FWD); + +// fwd.add_input_slot(A_INPUT, READ_WRITE); +// fwd.add_input_slot(B_INPUT, READ_WRITE); +// fwd.add_output_slot(OUTPUT); + +// return fwd; +// } + +// static OpTaskSignature get_bwd_task_signature() { +// OpTaskSignature bwd(OpTaskType::BWD); + +// bwd.add_input_slot(A_INPUT); +// bwd.add_input_slot(B_INPUT); +// bwd.add_input_grad_slot(A_INPUT_GRAD); +// bwd.add_input_grad_slot(B_INPUT_GRAD); +// bwd.add_output_slot(OUTPUT); +// bwd.add_output_grad_slot(OUTPUT_GRAD); + +// return bwd; +// } + +// OpTaskBinding BatchMatmul::get_init_task_binding() const { +// OpTaskBinding binding; + +// binding.bind_arg(ATTRS, this->attrs); +// binding.bind_arg(PROFILING, this->profiling); + +// return binding; +// } + +// OpTaskBinding BatchMatmul::get_fwd_task_binding() const { +// OpTaskBinding binding; + +// binding.bind(A_INPUT, input_tensor(0)); +// binding.bind(B_INPUT, input_tensor(1)); +// binding.bind(OUTPUT, output_tensor(0)); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + +//void BatchMatmul::backward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// backward_with_dim(ff); \ +// break; \ +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + + +// void BatchMatmul::print_layer(FFModel const &ff) { +// return; +// } + + + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + regions[5](I/O): B_grad + regions[6](I/O): C_grad +*/ +// template +// void BatchMatmul::backward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_BWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// // regions[0](I): output +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// // regions[1](I): output_grad +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// // regions[2](I): A +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(2, FID_DATA); +// // regions[3](I/O): A_grad +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // regions[4](I): B +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[1]->region)); +// launcher.add_field(4, FID_DATA); +// // regions[5](I/O): B_grad +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[1]->region_grad)); +// launcher.add_field(5, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + regions[5](I/O): B_grad + regions[6](I/O): C_grad +*/ \ No newline at end of file From d9aae744914fb6d710a5f174d4e88cc0cd7a95d5 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 22 Aug 2023 14:29:35 -0700 Subject: [PATCH 02/26] fix FF_VISITABLE_STRUCT_NO_EQ --- lib/kernels/include/kernels/batch_matmul_kernels.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index cdffdf1907..b8e6b9cd4b 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -14,7 +14,10 @@ struct BMMPerDeviceState { }; FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, - handle,); + handle, + allocator, + a_seq_length_dim, + b_seq_length_dim); namespace Kernels { namespace BatchMatmul { From 2073822dbc697615407763f10a8383e864c6929e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 12:03:39 -0700 Subject: [PATCH 03/26] finish draft 1 batch_matmul --- .../include/kernels/batch_matmul_kernels.h | 7 ++-- lib/runtime/include/runtime/config.h | 8 +++-- lib/runtime/src/ops/batch_matmul.cc | 32 ++++++++++--------- lib/runtime/src/task_spec/op_arg_ref.h | 8 ++++- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index b8e6b9cd4b..1506305e9e 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -10,7 +10,7 @@ struct BMMPerDeviceState { PerDeviceFFHandle handle; Allocator allocator; int a_seq_length_dim; - int b_seq_length_dim; + req b_seq_length_dim; }; FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, @@ -28,7 +28,7 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, int b_seq_length_dim); void forward_kernel(ffStream_t stream, - BMMPerDeviceState const *meta, + BMMPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, @@ -40,14 +40,13 @@ void forward_kernel(ffStream_t stream, int seq_length = -1); void backward_kernel(ffStream_t stream, - BMMPerDeviceState const *meta, + BMMPerDeviceState const &meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, float *a_grad_ptr, float const *b_ptr, float *b_grad_ptr, - float *c_grad_ptr, int m, int n, int k, diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index a7b8d86171..fc45c2c76c 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -104,13 +104,15 @@ struct FFConfig : public use_visitable_cmp { int python_data_loader_type = 2; }; -class FFIterationConfig { -public: +struct FFIterationConfig { FFIterationConfig(); void reset(); - int seq_length; + req seq_length; }; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, + seq_length); + enum FieldIDs { FID_DATA, }; diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 9bbc050f8b..761c565657 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -41,7 +41,8 @@ enum Slots { HANDLE, A_SEQ_LENGTH_DIM, B_SEQ_LENGTH_DIM, - PER_DEVICE_STATE + PER_DEVICE_STATE, + ITERATION_CONFIG }; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { @@ -63,6 +64,7 @@ OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { fwd.bind_arg(PROFILING, profiling_settings()); fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + fwd.bind_arg(ITERATION_CONFIG, iteration_config()); return {BATCHMATMUL_FWD_TASK_ID, fwd}; } @@ -110,7 +112,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; + FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); int m = b_input.shape[legion_dim_t(0)]; assert(m == output.shape[legion_dim_t(0)]); @@ -123,7 +125,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { + for (int i = 2; i < a_input.shape.get_dim(); i++) { //get_dim() or get_volume()? int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -137,12 +139,12 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { output.get_float_ptr(), a_input.get_float_ptr(), b_input.get_float_ptr(), - NULL, //c_ptr + nullptr, //c_ptr m, n, k, batch, - iter_config->seq_length); + iter_config.seq_length); } static void forward_task(Task const *task, @@ -159,7 +161,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(task->regions.size() == 6); // BatchMatmul* bmm = (BatchMatmul*) task->args; - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; + FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); @@ -186,7 +188,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { + for (int i = 2; i < a_input.shape.get_dim(); i++) { //@colin get_dim() or get_volume()? int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -195,8 +197,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { // TODO: add support for meta->a_seq_length_dim >= 0 // or meta->b_seq_length_dim >= 0 - assert((meta->a_seq_length_dim >= a_len) || (iter_config->seq_length == 0)); - assert((meta->b_seq_length_dim >= b_len) || (iter_config->seq_length == 0)); + assert((meta->a_seq_length_dim >= a_len) || (iter_config.seq_length == 0)); + assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); return profile(backward_kernel, profiling, @@ -208,7 +210,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { a_input_grad.get_float_ptr(), b_input.get_float_ptr(), b_input_grad.get_float_ptr(), - NULL, //c_grad_ptr m, n, k, @@ -228,10 +229,10 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, InputParallelTensorDesc const &a_input, InputParallelTensorDesc const &b_input, ProfilingSettings const &settings, - MachineView const &pc) const { + MachineView const &pc) { auto env = sim.new_environment(); - //todo add get_output_shape and get_weights_shape to batch_matmul op-attrs + // @colin todo add get_output_shape and get_weights_shape to batch_matmul op-attrs // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); @@ -248,6 +249,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding fwd_binding; fwd_binding.bind(A_INPUT, a_input); fwd_binding.bind(B_INPUT, b_input); + //@colin will uncomment after get_output_shape is done // fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); @@ -268,9 +270,9 @@ template <> void register_task() { OpTaskSignature init(OpTaskType::INIT); - init.add_arg_slot(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); - init.add_arg_slot(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); - init.add_unchecked_arg_slot(HANDLE, ff_handle()); + init.add_arg_slot(A_SEQ_LENGTH_DIM); + init.add_arg_slot(B_SEQ_LENGTH_DIM); + init.add_unchecked_arg_slot(HANDLE); register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); } diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index 6e921c05e8..bc4f38e5fb 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -4,10 +4,12 @@ #include "arg_ref.h" #include "device_specific_arg.h" #include "op-attrs/parallel_tensor_shape.h" +#include "runtime/config.h" + namespace FlexFlow { -enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE }; +enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE, ITERATION_CONFIG }; template using OpArgRef = ArgRef; @@ -23,6 +25,10 @@ OpArgRef input_parallel_tensor_shape(int idx) { return {OpArgRefType::PARALLEL_TENSOR_SHAPE}; } +OpArgRef iteration_config() { + return {OpArgRefType::ITERATION_CONFIG}; +} + } // namespace FlexFlow #endif From d5806a01163ab9fb13f0457c7cc4ba4e330b41d0 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 13:35:55 -0700 Subject: [PATCH 04/26] add output and weights --- lib/op-attrs/include/op-attrs/ops/batch_matmul.h | 3 +++ lib/runtime/src/ops/batch_matmul.cc | 15 ++++----------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 00c700ba20..fa80ed3da9 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -14,6 +14,9 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); +ParallelTensorShape get_weights_shape(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 761c565657..3fabc86d1a 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -13,16 +13,11 @@ * limitations under the License. */ -// #include "batch_matmul.h" -// #include "kernels/batch_matmul_kernels.h" -// #include "kernels/profiling.h" -// #include "legion/legion_utilities.h" -// #include "tasks.h" - #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" #include "legion.h" #include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/get_output_shapes.h" namespace FlexFlow { @@ -232,9 +227,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, MachineView const &pc) { auto env = sim.new_environment(); - // @colin todo add get_output_shape and get_weights_shape to batch_matmul op-attrs - // ParallelTensorShape output_shape = get_output_shape(attrs, inputs); - // ParallelTensorShape weight_shape = get_weights_shape(attrs, inputs); + ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape weight_shape = get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); @@ -249,8 +243,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding fwd_binding; fwd_binding.bind(A_INPUT, a_input); fwd_binding.bind(B_INPUT, b_input); - //@colin will uncomment after get_output_shape is done - // fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); From f620070ea8bc8e7373915a5063e1a5977ed7f433 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 13:45:00 -0700 Subject: [PATCH 05/26] format --- .../include/kernels/batch_matmul_kernels.h | 15 +-- .../include/op-attrs/ops/batch_matmul.h | 4 +- lib/runtime/include/runtime/config.h | 3 +- lib/runtime/src/ops/batch_matmul.cc | 106 +++++++++--------- lib/runtime/src/ops/batch_matmul.h | 62 +++++----- lib/runtime/src/task_spec/op_arg_ref.h | 7 +- 6 files changed, 97 insertions(+), 100 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 1506305e9e..23522a89f7 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -13,19 +13,16 @@ struct BMMPerDeviceState { req b_seq_length_dim; }; -FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState, - handle, - allocator, - a_seq_length_dim, - b_seq_length_dim); +FF_VISITABLE_STRUCT_NO_EQ( + BMMPerDeviceState, handle, allocator, a_seq_length_dim, b_seq_length_dim); namespace Kernels { namespace BatchMatmul { BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, - Allocator allocator, - int a_seq_length_dim, - int b_seq_length_dim); + Allocator allocator, + int a_seq_length_dim, + int b_seq_length_dim); void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, @@ -40,7 +37,7 @@ void forward_kernel(ffStream_t stream, int seq_length = -1); void backward_kernel(ffStream_t stream, - BMMPerDeviceState const &meta, + BMMPerDeviceState const &meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index fa80ed3da9..6ac2a2cf69 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -15,8 +15,8 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); ParallelTensorShape get_weights_shape(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); + ParallelTensorShape const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index fc45c2c76c..21cfca20cf 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -110,8 +110,7 @@ struct FFIterationConfig { req seq_length; }; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, - seq_length); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); enum FieldIDs { FID_DATA, diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 3fabc86d1a..0f4ad9e865 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -16,8 +16,8 @@ #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" #include "legion.h" -#include "op-attrs/ops/batch_matmul.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/batch_matmul.h" namespace FlexFlow { @@ -29,16 +29,16 @@ using Legion::Runtime; using Legion::Task; enum Slots { - A_INPUT, //tensor - B_INPUT, //tensor - OUTPUT, //tensor + A_INPUT, // tensor + B_INPUT, // tensor + OUTPUT, // tensor PROFILING, HANDLE, A_SEQ_LENGTH_DIM, B_SEQ_LENGTH_DIM, PER_DEVICE_STATE, ITERATION_CONFIG - }; +}; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { OpTaskBinding init; @@ -70,18 +70,16 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static DeviceSpecificArg init_task_impl(TaskArgumentAccessor const &acc) { +static DeviceSpecificArg + init_task_impl(TaskArgumentAccessor const &acc) { auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = acc.get_allocator(); - DeviceSpecificArg per_device_state = + DeviceSpecificArg per_device_state = acc.create_device_specific( - init_kernel(handle, - allocator, - a_seq_length_dim, - b_seq_length_dim)); + init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); // assert(weight.shape.get_volume() * sizeof(float) == // acc.unwrap(per_device_state)->weightSize); @@ -107,7 +105,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); - FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); int m = b_input.shape[legion_dim_t(0)]; assert(m == output.shape[legion_dim_t(0)]); @@ -120,7 +119,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); i++) { //get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.get_dim(); + i++) { // get_dim() or get_volume()? int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -128,18 +128,18 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { } return profile(forward_kernel, - profiling, - "[BatchMatmul] forward_time = %.2lfms\n", - per_device_state, - output.get_float_ptr(), - a_input.get_float_ptr(), - b_input.get_float_ptr(), - nullptr, //c_ptr - m, - n, - k, - batch, - iter_config.seq_length); + profiling, + "[BatchMatmul] forward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + a_input.get_float_ptr(), + b_input.get_float_ptr(), + nullptr, // c_ptr + m, + n, + k, + batch, + iter_config.seq_length); } static void forward_task(Task const *task, @@ -156,14 +156,15 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(task->regions.size() == 6); // BatchMatmul* bmm = (BatchMatmul*) task->args; - FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); // is this equivalent to checking `Domain` equality? - assert(output == output_grad); + assert(output == output_grad); auto a_input = acc.get_tensor(A_INPUT); auto a_input_grad = acc.get_tensor_grad(A_INPUT); @@ -183,7 +184,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); i++) { //@colin get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.get_dim(); + i++) { //@colin get_dim() or get_volume()? int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); @@ -196,19 +198,19 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); return profile(backward_kernel, - profiling, - "[BatchMatmul] backward_time = %.2lfms\n", - per_device_state, - output.get_float_ptr(), - output_grad.get_float_ptr(), - a_input.get_float_ptr(), - a_input_grad.get_float_ptr(), - b_input.get_float_ptr(), - b_input_grad.get_float_ptr(), - m, - n, - k, - batch); + profiling, + "[BatchMatmul] backward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + a_input.get_float_ptr(), + a_input_grad.get_float_ptr(), + b_input.get_float_ptr(), + b_input_grad.get_float_ptr(), + m, + n, + k, + batch); } static void backward_task(Task const *task, @@ -220,15 +222,17 @@ static void backward_task(Task const *task, } CostMetrics measure_operator_cost(SimEnvFactory const &sim, - BatchMatmulAttrs const &attrs, - InputParallelTensorDesc const &a_input, - InputParallelTensorDesc const &b_input, - ProfilingSettings const &settings, - MachineView const &pc) { + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc) { auto env = sim.new_environment(); - ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); - ParallelTensorShape weight_shape = get_weights_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape output_shape = + get_output_shape(attrs, a_input.shape, b_input.shape); + ParallelTensorShape weight_shape = + get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); @@ -249,8 +253,10 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - auto fwd_accessor = env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); - auto bwd_accessor = env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); + auto fwd_accessor = + env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); float forward_time = forward_task_impl(fwd_accessor).value(); float backward_time = backward_task_impl(bwd_accessor).value(); diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index ab9bb45e8d..42c9bc23de 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -24,11 +24,11 @@ OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim, - BatchMatmulAttrs const &attrs, - InputParallelTensorDesc const &a_input, - InputParallelTensorDesc const &b_input, - ProfilingSettings const &settings, - MachineView const &pc); + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc); /* class BatchMatmul : public Op { */ /* public: */ @@ -89,8 +89,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, #endif - - // BatchMatmulParams BatchMatmul::get_params() const { // BatchMatmulParams params; // params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; @@ -242,15 +240,14 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); // } - // void BatchMatmul::forward(FFModel const &ff) { // int dim = outputs[0]->num_dims; // switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ +// #define DIMFUNC(DIM) \ +// case DIM: { \ // // forward_with_dim(ff); -// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); -// break; +// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, +// get_fwd_task_signature()); break; // } // LEGION_FOREACH_N(DIMFUNC) // #undef DIMFUNC @@ -299,15 +296,14 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, output = A * B /////////+ C */ - // void BatchMatmul::init(FFModel const &ff) { // int dim = outputs[0]->num_dims; // switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ +// #define DIMFUNC(DIM) \ +// case DIM: { \ // // init_with_dim(ff); -// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); -// break; +// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, +// get_init_task_signature()); break; // } // LEGION_FOREACH_N(DIMFUNC) // #undef DIMFUNC @@ -365,7 +361,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return binding; // } - // static OpTaskSignature get_fwd_task_signature() { // OpTaskSignature fwd(OpTaskType::FWD); @@ -409,28 +404,25 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, // return binding; // } -//void BatchMatmul::backward(FFModel const &ff) { -// int dim = outputs[0]->num_dims; -// switch (dim) { -// #define DIMFUNC(DIM) \ -// case DIM: { \ -// backward_with_dim(ff); \ -// break; \ +// void BatchMatmul::backward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #d ef ine DIMFUNC(DIM) \ +// case DIM: { \ +// backward_with_dim(ff); \ +// break; \ // } -// LEGION_FOREACH_N(DIMFUNC) -// #undef DIMFUNC -// default: -// assert(false); -// } -// } - +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } // void BatchMatmul::print_layer(FFModel const &ff) { // return; // } - - /* regions[0](I): output regions[1](I): output_grad @@ -508,4 +500,4 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, regions[4](I): B regions[5](I/O): B_grad regions[6](I/O): C_grad -*/ \ No newline at end of file +*/ diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index bc4f38e5fb..07316c877b 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -6,10 +6,13 @@ #include "op-attrs/parallel_tensor_shape.h" #include "runtime/config.h" - namespace FlexFlow { -enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE, ITERATION_CONFIG }; +enum class OpArgRefType { + PER_DEVICE_OP_STATE, + PARALLEL_TENSOR_SHAPE, + ITERATION_CONFIG +}; template using OpArgRef = ArgRef; From 5458923bc344d2cefbd6bc18b7c3719e73291ba8 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Wed, 23 Aug 2023 15:23:36 -0700 Subject: [PATCH 06/26] fix DeviceSpecific --- lib/runtime/src/ops/batch_matmul.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 0f4ad9e865..59570f4417 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -70,14 +70,14 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static DeviceSpecificArg +static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = acc.get_allocator(); - DeviceSpecificArg per_device_state = + DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); @@ -86,7 +86,7 @@ static DeviceSpecificArg return per_device_state; } -static DeviceSpecificArg +static DeviceSpecific init_task(Task const *task, std::vector const ®ions, Context ctx, @@ -241,7 +241,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, auto init_accessor = env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, init_binding); - DeviceSpecificArg per_device_state = + DeviceSpecific per_device_state = init_task_impl(init_accessor); SimTaskBinding fwd_binding; From 2f604dca90edf8ecaef4dade5bbb6d0fc4ed612f Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 25 Aug 2023 18:57:40 -0700 Subject: [PATCH 07/26] batch_norm --- .../include/kernels/batch_norm_kernels.h | 70 +- lib/runtime/src/ops/batch_norm.cc | 700 ++++++------------ lib/runtime/src/ops/batch_norm.h | 242 +++++- 3 files changed, 537 insertions(+), 475 deletions(-) diff --git a/lib/kernels/include/kernels/batch_norm_kernels.h b/lib/kernels/include/kernels/batch_norm_kernels.h index 6ff90299db..74dfc96068 100644 --- a/lib/kernels/include/kernels/batch_norm_kernels.h +++ b/lib/kernels/include/kernels/batch_norm_kernels.h @@ -8,30 +8,66 @@ namespace FlexFlow { -class BatchNormPerDeviceState : public PerDeviceOpState { -public: - BatchNormPerDeviceState(FFHandler handle, - std::unique_ptr allocator, - int output_n, - int output_c, - int output_h, - int output_w, - bool relu, - bool profiling); - ~BatchNormPerDeviceState(void); - - ffTensorDescriptor_t inputTensor, outputTensor, biasTensor; +struct BatchNormPerDeviceState { + PerDeviceFFHandle handle; + Allocator allocator; + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffTensorDescriptor_t biasTensor; ffActivationDescriptor_t actiDesc; ffBatchNormMode_t mode; - float *runningMean, *runningVar, *saveMean, *saveVar; - bool relu; - bool profiling; - std::unique_ptr allocator; + float *runningMean; + float *runningVar; + float *saveMean; + float *saveVar; + int output_n; + int output_c; + int output_h; + int output_w; + ProfilingSettings profiling; + req relu; }; +FF_VISITABLE_STRUCT_NO_EQ(BatchNormPerDeviceState, + handle, + allocator, + inputTensor, + outputTensor, + biasTensor, + actiDesc, + mode, + runningMean, + runningVar, + saveMean, + saveVar, + output_n, + output_c, + output_h, + output_w, + profiling, + relu); + namespace Kernels { namespace BatchNorm { +BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle, + Allocator allocator, + ffTensorDescriptor_t inputTensor, + ffTensorDescriptor_t outputTensor, + ffTensorDescriptor_t biasTensor, + ffActivationDescriptor_t actiDesc, + ffBatchNormMode_t mode, + float *runningMean, + float *runningVar, + float *saveMean, + float *saveVar, + int output_n, + int output_c, + int output_h, + int output_w, + ProfilingSettings profiling, + bool relu); + void forward_kernel(ffStream_t stream, BatchNormPerDeviceState *m, float const *input_ptr, diff --git a/lib/runtime/src/ops/batch_norm.cc b/lib/runtime/src/ops/batch_norm.cc index 98cc4576a1..ffd52c96fb 100644 --- a/lib/runtime/src/ops/batch_norm.cc +++ b/lib/runtime/src/ops/batch_norm.cc @@ -16,505 +16,291 @@ #include "batch_norm.h" #include "kernels/batch_norm_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" + +namespace FlexFlow { using namespace FlexFlow::Kernels::BatchNorm; -namespace FlexFlow { +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; enum Slots { - INPUT, - SCALE, - BIAS, - OUTPUT, - INPUT_GRAD, - SCALE_GRAD, - BIAS_GRAD, - OUTPUT_GRAD, + INPUT, // tensor + SCALE, // tensor + BIAS, // tensor + OUTPUT, // tensor ATTRS, - PROFILING -} - -Tensor - FFModel::batch_norm(const Tensor input, bool relu, char const *name) { - assert(input->num_dims == 4); /*NCHW*/ - Layer *bm = new Layer(this, - OP_BATCHNORM, - DT_FLOAT, - name, - 1 /*inputs*/, - 2 /*weights*/, - 1 /*outputs*/, - input); - int numdims = 4; - bm->outputs[0] = create_tensor_legion_ordering( - numdims, input->dims, DT_FLOAT, bm, 0, true /*create_grad*/); - bm->add_int_property("relu", relu); - layers.push_back(bm); - return bm->outputs[0]; -} - -/* - locals[0] = scale - locals[1] = bias -*/ -BatchNorm::BatchNorm(FFModel &model, - const ParallelTensor _input, - const ParallelTensor _scale, - const ParallelTensor _bias, - bool _relu, - char const *name) - : Op(model, - OP_BATCHNORM, - DT_FLOAT, - name, - 1 /*inputs*/, - 2 /*weights*/, - 1 /*outputs*/, - _input, - _scale, - _bias), - relu(_relu) { - assert(_input->num_dims == 4); - numOutputs = 1; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < _input->num_dims; i++) { - dims[i] = _input->dims[_input->num_dims - 1 - i]; - } - outputs[0] = - model.create_parallel_tensor(_input->num_dims, dims, DT_FLOAT, this); - return; -} - -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - - // init.add_input_slot(INPUT); - // init.add_param_slot(SCALE); - // init.add_param_slot(BIAS); - init.add_output_slot(OUTPUT); -} + PROFILING, + PER_DEVICE_STATE, + RELU, + HANDLE +}; -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_param_slot(SCALE); - fwd.add_param_slot(BIAS); - fwd.add_output_slot(OUTPUT, WRITE_DISCARD); - - return fwd; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_arg_slot(ATTRS); - - bwd.add_input_slot(INPUT); - bwd.add_input_grad_slot(INPUT_GRAD, READ_WRITE); - bwd.add_param_slot(SCALE); - bwd.add_param_grad_slot(SCALE_GRAD, READ_WRITE); - bwd.add_param_grad_slot(BIAS_GRAD, READ_WRITE); - bwd.add_output_grad_slot(OUTPUT_GRAD); - - return bwd; -} - -OpTaskBinding BatchNorm::get_init_task_binding() const { +OpTaskInvocation init(BatchNormAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); - - // binding.bind(INPUT, input_tensor(0)); - // binding.bind(SCALE, param_tensor(0)); - // binding.bind(BIAS, param_tensor(1)); + binding.bind(INPUT, input_tensor(0)); + binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + binding.bind_arg(ATTRS, attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(HANDLE, ff_handle()); + + return {BATCHNORM_INIT_TASK_ID, binding}; } -OpTaskBinding BatchNorm::get_fwd_task_binding() const { +OpTaskInvocation forward(BatchNormAttrs const &attrs) { OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); binding.bind(INPUT, input_tensor(0)); - binding.bind(SCALE, param_tensor(0)); - binding.bind(BIAS, param_tensor(1)); + binding.bind(SCALE, input_tensor(1)); + binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {BATCHNORM_FWD_TASK_ID, binding}; } -OpTaskBinding BatchNorm::get_bwd_task_binding() const { - OpTaskBinding binding; +OpTaskInvocation backward(BatchNormAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - binding.bind_arg(ATTRS, this->attrs); + return {BATCHNORM_BWD_TASK_ID, binding}; +} - binding.bind(INPUT, input_tensor(0)); - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(SCALE, param_tensor(0)); - binding.bind(SCALE_GRAD, param_tensor(0).grad()); - binding.bind(BIAS_GRAD, param_tensor(1).grad()); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + Allocator allocator = acc.get_allocator(); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto output = acc.get_tensor(OUTPUT); + auto const &attrs = acc.get_argument(ATTRS); + + int output_w = output.shape[legion_dim_t(0)]; + int output_h = output.shape[legion_dim_t(1)]; + int output_c = output.shape[legion_dim_t(2)]; + int output_n = output.shape[legion_dim_t(3)]; + + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffTensorDescriptor_t biasTensor; + ffActivationDescriptor_t actiDesc; + ffBatchNormMode_t mode; + + size_t totalSize = sizeof(float) * output_c * 4; + float *runningMean = (float *)allocator.allocate(totalSize); + float *runningVar = (float *)runningMean + output_c; + float *saveMean = (float *)runningVar + output_c; + float *saveVar = (float *)saveMean + output_c; + + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, + allocator, + inputTensor, + outputTensor, + biasTensor, + actiDesc, + mode, + runningMean, + runningVar, + saveMean, + saveVar, + output_n, + output_c, + output_h, + output_w, + profiling, + attrs.relu)); + + return per_device_state; +} - return binding; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -void BatchNorm::init(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(BATCHNORM_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(BatchNorm)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[1]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[1]->region)); - // launcher.add_field(3, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + assert(regions.size() == 4); + assert(task->regions.size() == 4); + + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto scale = acc.get_tensor(SCALE); + auto bias = acc.get_tensor(SCALE); + + return profile(forward_kernel, + profiling, + "[BatchNorm] forward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output.get_float_ptr(), + scale.get_float_ptr(), + bias.get_float_ptr()); } -/* - regions[0]: input - regions[1]: output - regions[2](I): scale - regions[3](I): bias -*/ -PerDeviceOpState * - BatchNorm::init_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); TaskArgumentAccessor acc(task, regions, ctx, runtime); - FFHandler handle = *((FFHandler const *)task->local_args); - - auto output = acc.get_tensor(OUTPUT); - - int output_w = output.shape[0]; - int output_h = output.shape[1]; - int output_c = output.shape[2]; - int output_n = output.shape[3]; - - Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) - .only_kind(Memory::GPU_FB_MEM) - .best_affinity_to(task->target_proc) - .first(); - BatchNormPerDeviceState *m = new BatchNormPerDeviceState( - handle, bm, gpu_mem, output_n, output_c, output_h, output_w); - return m; + forward_task_impl(acc); } -void BatchNorm::forward(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(BATCHNORM_FWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, 0), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_DISCARD, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[1]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[1]->region)); - // launcher.add_field(3, FID_DATA); - - // runtime->execute_index_space(ctx, launcher); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + assert(regions.size() == 7); + assert(task->regions.size() == 7); + + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + auto scale = acc.get_tensor(SCALE); + auto scale_grad = acc.get_tensor_grad(SCALE); + auto bias_grad = acc.get_tensor_grad(BIAS); + + return profile(backward_kernel, + profiling, + "[BatchNorm] backward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output_grad.get_float_ptr(), + output.get_float_ptr(), + input_grad.get_float_ptr(), + scale.get_float_ptr(), + scale_grad.get_float_ptr(), + bias_grad.get_float_ptr(), + output.shape.get_volume()); } -/* - regions[0](I): input - regions[1](O): ouptut - regions[2](I): scale - regions[3](I): bias -*/ -void BatchNorm::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - // const BatchNorm* bm = (BatchNorm*) task->args; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - BatchNormPerDeviceState *m = *((BatchNormPerDeviceState **)task->local_args); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto scale = acc.get_tensor(SCALE); - auto bias = acc.get_tensor(SCALE); - - profile(forward_kernel, - m->profiling, - "[BatchNorm] forward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output.get_float_ptr(), - scale.get_float_ptr(), - bias.get_float_ptr()); + backward_task_impl(acc); } -void BatchNorm::backward(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(BATCHNORM_BWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, 0), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // // regions[0](I): input - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // // regions[1](I/O): input_grad (we only need grad tensors) - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // inputs[0]->region_grad)); - // launcher.add_field(1, FID_DATA); - // // regions[2](I): output - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(2, FID_DATA); - // // regions[3](I/O): output_grad - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(3, FID_DATA); - // // regions[4](I): filter - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(4, FID_DATA); - // // regions[5](I/O): filter_grad - // launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // weights[0]->region_grad)); - // launcher.add_field(5, FID_DATA); - // // regions[6](I/O): bias_grad - // launcher.add_region_requirement(RegionRequirement(weights[1]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // weights[1]->region_grad)); - // launcher.add_field(6, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchNormAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + InputParallelTensorDesc const &scale_shape, + InputParallelTensorDesc const &bias_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + + // int output_w = sub_output.dims[0].size; + // int output_h = sub_output.dims[1].size; + // int output_c = sub_output.dims[2].size; + // int output_n = sub_output.dims[3].size; + // BatchNormPerDeviceState *m = new BatchNormPerDeviceState( + // sim->handler, this, sim->memory, output_n, output_c, output_h, + // output_w); + + // sim->free_all(); + // float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), + // DT_FLOAT); assert(input_ptr != NULL); cost_metrics.inputs_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + // float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), + // DT_FLOAT); assert(output_ptr != NULL); cost_metrics.outputs_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + // float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); + // assert(bias_ptr != NULL); + // float *scale_ptr = (float *)sim->allocate(output_c, DT_FLOAT); + // assert(scale_ptr != NULL); + // cost_metrics.weights_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + auto env = sim.new_environment(); + + ParallelTensorShape output_shape = get_output_shape(attrs); + + SimTaskBinding init_binding; + init_binding.bind(INPUT, input_shape); + init_binding.bind(BIAS, bias_shape); + init_binding.bind(OUTPUT, output_shape); + + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(PROFILING, settings); + init_binding.bind_arg(HANDLE, ff_handle()); + + auto init_accessor = + env.get_init_accessor(ATTENTION_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(SCALE, scale_shape); + fwd_binding.bind(BIAS, bias_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = env.get_fwd_accessor(ATTENTION_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(ATTENTION_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -/* - regions[0](I): input - regions[1](I/O): input_grad - regions[2](I): output - regions[3](I/O): output_grad - regions[4](I): scale - regions[5](I/O): scale_grad - regions[6](I/O): bias_grad -*/ -__host__ void - BatchNorm::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 7); - assert(task->regions.size() == 7); - // float beta = 0.0f; - // const BatchNorm* bm = (BatchNorm*) task->args; - TaskArgumentAccessor acc(task, regions, ctx, runtime); - BatchNormPerDeviceState *m = *((BatchNormPerDeviceState **)task->local_args); - - auto input = acc.get_tensor(INPUT); - auto input_grad = acc.get_tensor_grad(INPUT_GRAD); - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT_GRAD); - auto scale = acc.get_tensor(SCALE); - auto scale_grad = acc.get_tensor_grad(SCALE_GRAD); - auto bias_grad = acc.get_tensor_grad(BIAS_GRAD); - - profile(backward_kernel, - m->profiling, - "[BatchNorm] backward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output_grad.get_float_ptr(), - output.get_float_ptr(), - input_grad.get_float_ptr(), - scale.get_float_ptr(), - scale_grad.get_float_ptr(), - bias_grad.get_float_ptr(), - output.get_volume()); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + init.add_input_slot(INPUT); + init.add_input_slot(BIAS); + init.add_output_slot(OUTPUT); + init.add_arg_slot(ATTRS); + init.add_arg_slot(PROFILING); + init.add_unchecked_arg_slot(HANDLE); + + register_task(BATCHNORM_INIT_TASK_ID, "BatchNorm Init", init, init_task); } -bool BatchNorm::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_input, sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - - int output_w = sub_output.dims[0].size; - int output_h = sub_output.dims[1].size; - int output_c = sub_output.dims[2].size; - int output_n = sub_output.dims[3].size; - BatchNormPerDeviceState *m = new BatchNormPerDeviceState( - sim->handler, this, sim->memory, output_n, output_c, output_h, output_w); - - sim->free_all(); - float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_ptr != NULL); - float *scale_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(scale_ptr != NULL); - cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset); - - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, m, input_ptr, output_ptr, scale_ptr, bias_ptr); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - float *input_grad_ptr = - (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_grad_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - float *scale_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(scale_grad_ptr != NULL); - float *bias_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_grad_ptr != NULL); - cost_metrics.weights_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - m, - input_ptr, - output_grad_ptr, - output_ptr, - input_grad_ptr, - scale_ptr, - scale_grad_ptr, - bias_grad_ptr, - sub_output.get_volume()); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchNorm] name(%s) size(%zu) forward_time(%.4lf) " - "backward_time(%.4lf)\n", - name, - sub_input.get_volume(), - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchNorm] name(%s) size(%zu) forward_time(%.4lf)\n", - name, - sub_input.get_volume(), - cost_metrics.forward_time); - } - // Free batchnormmeta - delete m; - return true; +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_input_slot(INPUT); + fwd.add_input_slot(SCALE); + fwd.add_input_slot(BIAS); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + + register_task(BATCHNORM_FWD_TASK_ID, "BatchNorm Fwd", fwd, forward_task); +} + +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(BATCHNORM_FWD_TASK_ID)); + + register_task(BATCHNORM_BWD_TASK_ID, "BatchNorm Bwd", bwd, backward_task); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_norm.h b/lib/runtime/src/ops/batch_norm.h index e54331665e..94bda5122b 100644 --- a/lib/runtime/src/ops/batch_norm.h +++ b/lib/runtime/src/ops/batch_norm.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_BATCH_NORM_H #include "op-attrs/ops/batch_norm.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -66,3 +66,243 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + +// void BatchNorm::init(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); +// } + +// void BatchNorm::forward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); +// } + +// void BatchNorm::backward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); +// } + +// Tensor batch_norm(const Tensor input, bool relu, char const *name) { +// assert(input->num_dims == 4); /*NCHW*/ +// Layer *bm = new Layer(this, +// OP_BATCHNORM, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// 2 /*weights*/, +// 1 /*outputs*/, +// input); +// int numdims = 4; +// bm->outputs[0] = create_tensor_legion_ordering( +// numdims, input->dims, DT_FLOAT, bm, 0, true /*create_grad*/); +// bm->add_int_property("relu", relu); +// layers.push_back(bm); +// return bm->outputs[0]; +// } + +// BatchNorm::BatchNorm(FFModel &model, +// const ParallelTensor _input, +// const ParallelTensor _scale, +// const ParallelTensor _bias, +// bool _relu, +// char const *name) +// : Op(model, +// OP_BATCHNORM, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// 2 /*weights*/, +// 1 /*outputs*/, +// _input, +// _scale, +// _bias), +// relu(_relu) { +// assert(_input->num_dims == 4); +// numOutputs = 1; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < _input->num_dims; i++) { +// dims[i] = _input->dims[_input->num_dims - 1 - i]; +// } +// outputs[0] = +// model.create_parallel_tensor(_input->num_dims, dims, DT_FLOAT, this); +// return; +// } + +/* + locals[0] = scale + locals[1] = bias +*/ + +// void BatchNorm::init(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); +// assert(check_output_input_weight_same_parallel_is()); +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_init(ff, argmap); +// IndexLauncher launcher(BATCHNORM_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(BatchNorm)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[1]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[1]->region)); +// launcher.add_field(3, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// set_opmeta_from_futuremap(ff, fm); +// } + +/* + regions[0]: input + regions[1]: output + regions[2](I): scale + regions[3](I): bias +*/ + +// void BatchNorm::forward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher(BATCHNORM_FWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, 0), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_DISCARD, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[1]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[1]->region)); +// launcher.add_field(3, FID_DATA); + +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): input + regions[1](O): ouptut + regions[2](I): scale + regions[3](I): bias +*/ + +// void BatchNorm::backward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher(BATCHNORM_BWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, 0), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// // regions[0](I): input +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// // regions[1](I/O): input_grad (we only need grad tensors) +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// // regions[2](I): output +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(2, FID_DATA); +// // regions[3](I/O): output_grad +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // regions[4](I): filter +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(4, FID_DATA); +// // regions[5](I/O): filter_grad +// launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// weights[0]->region_grad)); +// launcher.add_field(5, FID_DATA); +// // regions[6](I/O): bias_grad +// launcher.add_region_requirement(RegionRequirement(weights[1]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// weights[1]->region_grad)); +// launcher.add_field(6, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): input + regions[1](I/O): input_grad + regions[2](I): output + regions[3](I/O): output_grad + regions[4](I): scale + regions[5](I/O): scale_grad + regions[6](I/O): bias_grad +*/ From a32f4e55b34d400b4de8a7daed52f2390d116cc5 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Mon, 28 Aug 2023 10:53:23 -0700 Subject: [PATCH 08/26] cast op --- lib/kernels/include/kernels/cast_kernels.h | 17 +- lib/runtime/src/ops/cast.cc | 501 +++++---------------- lib/runtime/src/ops/cast.h | 267 ++++++++++- 3 files changed, 394 insertions(+), 391 deletions(-) diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h index d43446883c..28985f5501 100644 --- a/lib/kernels/include/kernels/cast_kernels.h +++ b/lib/kernels/include/kernels/cast_kernels.h @@ -3,19 +3,26 @@ #include "kernels/accessor.h" #include "kernels/device.h" -#include "op-attrs/ffconst.h" namespace FlexFlow { -class CastPerDeviceState : public PerDeviceOpState { -public: - CastPerDeviceState(FFHandler handle); - DataType input_data_type, output_data_type; +struct CastPerDeviceState { + PerDeviceFFHandle handle; + DataType input_data_type; + req output_data_type; }; +FF_VISITABLE_STRUCT_NO_EQ(CastPerDeviceState, + handle, + input_data_type, + output_data_type); + namespace Kernels { namespace Cast { +CastPerDeviceState + init_kernel(PerDeviceFFHandle const &, DataType input, DataType output); + void forward_kernel(ffStream_t stream, CastPerDeviceState const *, GenericTensorAccessorR const &input, diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index 23c1bc9940..36afdefcef 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -16,441 +16,178 @@ #include "cast.h" #include "kernels/cast_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" #include "utils/hash-utils.h" using namespace FlexFlow::Kernels::Cast; -namespace FlexFlow { - -enum Slots { - INPUT, - OUTPUT, - INPUT_GRAD, - OUTPUT_GRAD, - ATTRS, - PROFILING -} - -// declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; - -Tensor FFModel::cast(const Tensor input, DataType dtype, char const *name) { - Layer *cast = new Layer(this, - OP_CAST, - dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input); - int numdims = input->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = input->dims[i]; - } - cast->outputs[0] = create_tensor_legion_ordering( - numdims, dims, dtype, cast, 0, true /*create_grad*/); - cast->add_int_property("dtype", dtype); - layers.push_back(cast); - return cast->outputs[0]; -} - -Op *Cast::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("dtype", value); - DataType dtype = (DataType)value; - return new Cast(model, inputs[0], dtype, layer->name); -} - -CastParams Cast::get_params() const { - CastParams params; - params.dtype = this->outputs[0]->data_type; - return params; -} - -Cast::Cast(FFModel &model, - ParallelTensor const &input, - DataType _dtype, - char const *name) - : Op(model, - OP_CAST, - _dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input) { - numOutputs = 1; - numWeights = 0; - int numdim = input->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = input->dims[i]; - } - outputs[0] = - model.create_parallel_tensor_legion_ordering(numdim, dims, _dtype, this); -} - -Cast::Cast(FFModel &model, - CastParams const ¶ms, - ParallelTensor const &input, - char const *name) - : Cast(model, input, params.dtype, name) {} - -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - - init.add_input_slot(INPUT); - init.add_output_slot(OUTPUT); - - return init; -} - -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - - return init; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_arg_slot(ATTRS); - bwd.add_input_grad_slot(INPUT_GRAD); - bwd.add_output_grad_slot(OUTPUT_GRAD); +namespace FlexFlow { - return bwd; -} +enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, PER_DEVICE_STATE, HANDLE }; -OpTaskBinding Cast::get_init_task_binding() const { +// declare Legion names +// using Legion::ArgumentMap; +// using Legion::Context; +// using Legion::coord_t; +// using Legion::Domain; +// using Legion::FutureMap; +// using Legion::IndexLauncher; +// using Legion::PhysicalRegion; +// using Legion::Predicate; +// using Legion::Rect; +// using Legion::RegionRequirement; +// using Legion::Runtime; +// using Legion::Task; +// using Legion::TaskArgument; +// using Legion::TaskLauncher; + +OpTaskInvocation init(CastAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(PROFILING, this->profiling); - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(HANDLE, ff_handle()); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {CAST_INIT_TASK_ID, binding}; } -OpTaskBinding Cast::get_fwd_task_binding() const { +OpTaskInvocation forward(CastAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind_arg(PROFILING, profiling_settings()); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {CAST_FWD_TASK_ID, binding}; } -OpTaskBinding Cast::get_bwd_task_binding() const { - OpTaskBinding binding; +OpTaskInvocation backward(CastAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - binding.bind_arg(ATTRS, this->attrs); + return {CAST_BWD_TASK_ID, binding}; +} - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { - return binding; -} + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); -void Cast::init(FFModel const &ff) { - this->execute_task(ff, CAST_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(CAST_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Cast)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, input.data_type, output.data_type)); + return per_device_state; } -PerDeviceOpState *Cast::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - - FFHandler handler = *((FFHandler const *)task->local_args); - CastPerDeviceState *m = new CastPerDeviceState(handler); - bool profiling = acc.get_argument(PROFILING); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - m->input_data_type = input->data_type; - m->output_data_type = output->data_type; - m->profiling = profiling; - return m; + return init_task_impl(acc); } -void Cast::forward(FFModel const &ff) { - this->execute_task(ff, CAST_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(CAST_FWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, false), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // runtime->execute_index_space(ctx, launcher); -} +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); -// template -// void Cast::forward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->output_data_type == DT_FLOAT) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_DOUBLE) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_INT32) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->output_data_type == DT_INT64) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::forward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT *input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerWO( -// regions[1], task->regions[1], FID_DATA, ctx, runtime); -// forward_kernel_wrapper( -// m, input_ptr, output_ptr, output_domain.get_volume()); -// } - -void Cast::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); - // if (m->input_data_type == DT_FLOAT) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_DOUBLE) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_INT32) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_INT64) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - profile(forward_kernel, - m->profiling, - "[Cast] forward_time = %.2lfms\n", - m, - input, - output) -} + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); -void Cast::backward(FFModel const &ff) { - this->execute_task(ff, CAST_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(CAST_BWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, false), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // inputs[0]->region_grad)); - // launcher.add_field(1, FID_DATA); - // runtime->execute_index_space(ctx, launcher); + return profile(forward_kernel, + profiling, + "[Cast] forward_time = %.2lfms\n", + &per_device_state, + input, + output); } -// template -// void Cast::backward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->input_data_type == DT_FLOAT) { -// Cast::backward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->input_data_type == DT_DOUBLE) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT32) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT64) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::backward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT *input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerRW( -// regions[1], task->regions[1], FID_DATA, ctx, runtime); -// backward_kernel_wrapper( -// input_ptr, output_ptr, output_domain.get_volume()); -// } - -void Cast::backward_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); - // if (m->output_data_type == DT_FLOAT) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_DOUBLE) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_INT32) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_INT64) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } - auto input_grad = acc.get_tensor(INPUT); - auto output_grad = acc.get_tensor(OUTPUT); - - profile(backward_kernel, - m->profiling, - "[Cast] forward_time = %.2lfms\n", - m, - input_grad, - output_grad) + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); +} + +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + + return profile(backward_kernel, + profiling, + "[Cast] forward_time = %.2lfms\n", + &per_device_state, + input_grad, + output_grad); +} + +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -bool Cast::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + CastAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + // Assume cast has no cost - cost_metrics.forward_time = 0.0f; - cost_metrics.backward_time = 0.0f; - cost_metrics.inputs_memory = 0; - cost_metrics.outputs_memory = 0; - cost_metrics.weights_memory = 0; - return true; + float forward_time = 0.0; + float backward_time = 0.0; + float sync_time = 0.0; + return make_metrics(forward_time, backward_time, sync_time, env); } -void Cast::serialize(Legion::Serializer &sez) const { - sez.serialize(this->outputs[0]->data_type); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + + init.add_unchecked_arg_slot(HANDLE); + + init.add_input_slot(INPUT); + init.add_output_slot(OUTPUT); + + register_task(CAST_INIT_TASK_ID, "Cast Init", init, init_task); } -using PCG::Node; +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); -Node Cast::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 1); - DataType dtype; - dez.deserialize(dtype); - return ff.get_or_create_node(inputs[0], {dtype}); + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + + register_task(CAST_FWD_TASK_ID, "Cast Fwd", fwd, forward_task); } -Op *Cast::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - assert(num_inputs == 1); - return new Cast(ff, inputs[0], this->outputs[0]->data_type, this->name); +template <> +void register_task() { + OpTaskSignature bwd = infer_bwd_signature(get_op_signature(CAST_FWD_TASK_ID)); + + register_task(CAST_BWD_TASK_ID, "Cast Bwd", bwd, backward_task); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/cast.h b/lib/runtime/src/ops/cast.h index 7d346584d5..c3781ad783 100644 --- a/lib/runtime/src/ops/cast.h +++ b/lib/runtime/src/ops/cast.h @@ -16,8 +16,8 @@ #define _FLEXFLOW_CAST_H #include "op-attrs/ops/cast.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -33,11 +33,54 @@ OpTaskInvocation forward(CastAttrs const &); OpTaskInvocation backward(CastAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - BatchNormAttrs const &attrs, + CastAttrs const &attrs, ParallelTensorShape const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); +} // namespace FlexFlow + +#endif + +// template +// void Cast::backward_task_with_1_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// if (m->input_data_type == DT_FLOAT) { +// Cast::backward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->input_data_type == DT_DOUBLE) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->input_data_type == DT_INT32) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->input_data_type == DT_INT64) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } +// } + +// template +// void Cast::backward_task_with_2_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// assert(regions.size() == 2); +// assert(task->regions.size() == regions.size()); +// // Domain input_domain = runtime->get_index_space_domain( +// // ctx, task->regions[0].region.get_index_space()); +// Domain output_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// const IDT *input_ptr = helperGetTensorPointerRO( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// ODT *output_ptr = helperGetTensorPointerRW( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); +// backward_kernel_wrapper( +// input_ptr, output_ptr, output_domain.get_volume()); +// } + /* class Cast : public Op { */ /* public: */ /* Cast(FFModel &model, */ @@ -80,6 +123,222 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, /* CostMetrics &cost_metrics) const; */ /* }; */ -} // namespace FlexFlow +// void Cast::backward(FFModel const &ff) { +// this->execute_task(ff, CAST_BWD_TASK_ID, get_bwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher(CAST_BWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, false), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } -#endif +// template +// void Cast::forward_task_with_1_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// if (m->output_data_type == DT_FLOAT) { +// Cast::forward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->output_data_type == DT_DOUBLE) { +// Cast::forward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->output_data_type == DT_INT32) { +// Cast::forward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->output_data_type == DT_INT64) { +// Cast::forward_task_with_2_type(task, regions, ctx, +// runtime); +// } +// } + +// template +// void Cast::forward_task_with_2_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// assert(regions.size() == 2); +// assert(task->regions.size() == regions.size()); +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// // Domain input_domain = runtime->get_index_space_domain( +// // ctx, task->regions[0].region.get_index_space()); +// Domain output_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// const IDT *input_ptr = helperGetTensorPointerRO( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// ODT *output_ptr = helperGetTensorPointerWO( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); +// forward_kernel_wrapper( +// m, input_ptr, output_ptr, output_domain.get_volume()); +// } + +// void Cast::forward(FFModel const &ff) { +// this->execute_task(ff, CAST_FWD_TASK_ID, get_fwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher(CAST_FWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, false), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +// void Cast::init(FFModel const &ff) { +// this->execute_task(ff, CAST_INIT_TASK_ID, get_init_task_signature()); +// assert(check_output_input_weight_same_parallel_is()); +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_init(ff, argmap); +// IndexLauncher launcher(CAST_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(Cast)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// set_opmeta_from_futuremap(ff, fm); +// } + +// void Cast::serialize(Legion::Serializer &sez) const { +// sez.serialize(this->outputs[0]->data_type); +// } + +// using PCG::Node; + +// Node Cast::deserialize(FFModel &ff, +// Legion::Deserializer &dez, +// ParallelTensor inputs[], +// int num_inputs) { +// assert(num_inputs == 1); +// DataType dtype; +// dez.deserialize(dtype); +// return ff.get_or_create_node(inputs[0], {dtype}); +// } + +// Op *Cast::materialize(FFModel &ff, +// ParallelTensor inputs[], +// int num_inputs) const { +// assert(num_inputs == 1); +// return new Cast(ff, inputs[0], this->outputs[0]->data_type, this->name); +// } + +// Cast::Cast(FFModel &model, +// ParallelTensor const &input, +// DataType _dtype, +// char const *name) +// : Op(model, +// OP_CAST, +// _dtype, +// name, +// 1 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// input) { +// numOutputs = 1; +// numWeights = 0; +// int numdim = input->num_dims; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdim; i++) { +// dims[i] = input->dims[i]; +// } +// outputs[0] = +// model.create_parallel_tensor_legion_ordering(numdim, dims, _dtype, +// this); +// } + +// Tensor FFModel::cast(const Tensor input, DataType dtype, char const *name) { +// Layer *cast = new Layer(this, +// OP_CAST, +// dtype, +// name, +// 1 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// input); +// int numdims = input->num_dims; +// int dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdims; i++) { +// dims[i] = input->dims[i]; +// } +// cast->outputs[0] = create_tensor_legion_ordering( +// numdims, dims, dtype, cast, 0, true /*create_grad*/); +// cast->add_int_property("dtype", dtype); +// layers.push_back(cast); +// return cast->outputs[0]; +// } + +// Op *Cast::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("dtype", value); +// DataType dtype = (DataType)value; +// return new Cast(model, inputs[0], dtype, layer->name); +// } + +// CastParams Cast::get_params() const { +// CastParams params; +// params.dtype = this->outputs[0]->data_type; +// return params; +// } + +// Cast::Cast(FFModel &model, +// CastParams const ¶ms, +// ParallelTensor const &input, +// char const *name) +// : Cast(model, input, params.dtype, name) {} From d02998d2b1aed5cd8d1c9c9ba4b687d275844914 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Mon, 28 Aug 2023 13:53:04 -0700 Subject: [PATCH 09/26] combine --- lib/kernels/include/kernels/combine_kernels.h | 11 +- lib/op-attrs/include/op-attrs/get_op_type.h | 1 + .../include/op-attrs/operator_attrs.h | 2 + lib/runtime/src/ops/combine.cc | 391 ++++++------------ lib/runtime/src/ops/combine.h | 276 ++++++++++++- 5 files changed, 400 insertions(+), 281 deletions(-) diff --git a/lib/kernels/include/kernels/combine_kernels.h b/lib/kernels/include/kernels/combine_kernels.h index 174d1eb925..44ab67d9a7 100644 --- a/lib/kernels/include/kernels/combine_kernels.h +++ b/lib/kernels/include/kernels/combine_kernels.h @@ -6,15 +6,18 @@ namespace FlexFlow { -class CombinePerDeviceState : public PerDeviceOpState { -public: - CombinePerDeviceState(FFHandler handle); - DataType data_type; +struct CombinePerDeviceState { + req data_type; }; +FF_VISITABLE_STRUCT_NO_EQ(CombinePerDeviceState, + data_type); + namespace Kernels { namespace Combine { +CombinePerDeviceState init_kernel(DataType data_type); + void forward_kernel(ffStream_t stream, CombinePerDeviceState const *m, GenericTensorAccessorR const &input, diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h index 8b451b2705..910d5dc925 100644 --- a/lib/op-attrs/include/op-attrs/get_op_type.h +++ b/lib/op-attrs/include/op-attrs/get_op_type.h @@ -12,6 +12,7 @@ OperatorType get_op_type(BatchMatmulAttrs const &); OperatorType get_op_type(BatchNormAttrs const &); OperatorType get_op_type(BroadcastAttrs const &); OperatorType get_op_type(CastAttrs const &); +OperatorType get_op_type(CombineAttrs const &); OperatorType get_op_type(ConcatAttrs const &); OperatorType get_op_type(Conv2DAttrs const &); OperatorType get_op_type(DropoutAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 5fd067313e..b64fe73497 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -43,6 +43,7 @@ using SharedOperatorAttrs = variant::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); +static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); diff --git a/lib/runtime/src/ops/combine.cc b/lib/runtime/src/ops/combine.cc index 2485955124..c2d8d9a017 100644 --- a/lib/runtime/src/ops/combine.cc +++ b/lib/runtime/src/ops/combine.cc @@ -13,322 +13,163 @@ * limitations under the License. */ -#include "parallel_ops/combine.h" +#include "combine.h" #include "kernels/combine_kernels.h" #include "utils/hash-utils.h" namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::LogicalPartition; -using Legion::LogicalRegion; -using Legion::Machine; -using Legion::Memory; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Combine; -CombineParams Combine::get_params() const { - CombineParams params; - params.combine_legion_dim = this->combine_dim; - params.combine_degree = this->combine_degree; - return params; -} +enum Slots { + INPUT, + OUTPUT, + PROFILING, + PER_DEVICE_STATE +}; -Combine::Combine(FFModel &model, - CombineParams const ¶ms, - ParallelTensor const input, - char const *name) - : Combine(model, - input, - params.combine_legion_dim, - params.combine_degree, - name) {} - -Combine::Combine(FFModel &model, - const ParallelTensor _input, - int _combine_legion_dim, - int _combine_degree, - char const *name) - : ParallelOp(model, OP_COMBINE, name, _input), - combine_dim(_combine_legion_dim), combine_degree(_combine_degree) { - int numdim = _input->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = _input->dims[i]; - } - assert(combine_degree > 0 && "Must use combine_degree > 0"); - assert(dims[combine_dim].degree % combine_degree == 0); - dims[combine_dim].degree /= combine_degree; - ParallelTensorBase::update_parallel_ids(numdim, dims); - outputs[0] = model.create_parallel_tensor_legion_ordering( - numdim, dims, DT_FLOAT, this); - // inputs[0]->print("Combine::input"); - // outputs[0]->print("Combine::output"); -} +OpTaskInvocation init(CombineAttrs const &attrs) { + OpTaskBinding binding; + + binding.bind(INPUT, input_tensor(0)); -PerDeviceOpState *Combine::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - Combine *rep = (Combine *)task->args; - // FFHandler handle = *((FFHandler *)task->local_args); - // CombineMeta* m = new CombineMeta(handle); - // m->data_type = rep->outputs[0]->data_type; - return nullptr; + return {COMBINE_INIT_TASK_ID, binding}; } -void Combine::init(FFModel const &ff) { - parallel_is = outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - assert(numOutputs == 1); - assert(numInputs == 1); - IndexLauncher launcher(COMBINE_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(Combine)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement( - input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(1, FID_DATA); - FutureMap fm = runtime->execute_index_space(ctx, launcher); - fm.wait_all_results(); +OpTaskInvocation forward(CombineAttrs const &attrs) { + OpTaskBinding binding; + + binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind_arg(PROFILING, profiling_settings()); + + binding.bind(INPUT, input_tensor(0)); + binding.bind(OUTPUT, output_tensor(0)); + + return {COMBINE_FWD_TASK_ID, binding}; } -void Combine::create_input_partition(FFModel &ff) { - assert(outputs[0]->part != LogicalPartition::NO_PART); - assert(inputs[0]->part != LogicalPartition::NO_PART); - ff.create_disjoint_partition(outputs[0]->num_dims, - outputs[0]->dims, - outputs[0]->parallel_is, - inputs[0]->region, - input_lp); - ff.create_disjoint_partition(inputs[0]->num_dims, - inputs[0]->dims, - inputs[0]->parallel_is, - outputs[0]->region_grad, - output_grad_lp); +OpTaskInvocation backward(CombineAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); + + return {COMBINE_BWD_TASK_ID, b}; } -void Combine::forward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - assert(numOutputs == 1); - assert(numInputs == 1); - assert(inputs[0]->data_type == outputs[0]->data_type); - DataType data_type = inputs[0]->data_type; - IndexLauncher launcher(COMBINE_FWD_TASK_ID, - outputs[0]->parallel_is, - TaskArgument(&data_type, sizeof(data_type)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement( - input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(1, FID_DATA); - runtime->execute_index_space(ctx, launcher); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + + auto input = acc.get_tensor(INPUT); + + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(input.data_type)); + return per_device_state; } -void Combine::backward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - assert(numOutputs == 1); - assert(numInputs == 1); - assert(inputs[0]->data_type == outputs[0]->data_type); - DataType data_type = inputs[0]->data_type; - IndexLauncher launcher(COMBINE_BWD_TASK_ID, - inputs[0]->parallel_is, - TaskArgument(&data_type, sizeof(DataType)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - inputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement(output_grad_lp, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - inputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - runtime->execute_index_space(ctx, launcher); +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -bool Combine::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - // TODO: to be implemented - cost_metrics = CostMetrics(); - cost_metrics.forward_time = 0.05f; - cost_metrics.backward_time = 0.05f; - return true; +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + + return profile(forward_kernel, + profiling, + "[Combine] forward_time = %.2lfms\n", + &per_device_state, + input, + output); } -bool Combine::get_int_parameter(PMParameter para, int *value) const { - switch (para) { - case PM_COMBINE_DIM: - *value = combine_dim; - return true; - case PM_COMBINE_DEGREE: - *value = combine_degree; - return true; - default: - return Op::get_int_parameter(para, value); - } +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -bool Combine::append_parallel_op_info( - std::vector ¶llel_ops) const { - ParallelOpInfo ret; - ret.op_type = op_type; - ret.parallel_dim = combine_dim; - ret.parallel_degree = combine_degree; - parallel_ops.push_back(ret); - return true; +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + + return profile(backward_kernel, + profiling, + "[Combine] forward_time = %.2lfms\n", + &per_device_state, + input_grad, + output_grad); } -tl::optional Combine::as_dot() const { - RecordFormatter rf; - { - std::ostringstream oss; - oss << "dim(" << this->combine_dim << ")"; - rf << oss.str(); - } - { - std::ostringstream oss; - oss << "deg(" << this->combine_degree << ")"; - rf << oss.str(); - } - return rf; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -/*static*/ -void Combine::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - DataType data_type = *((DataType *)task->args); - if (data_type == DT_FLOAT) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_DOUBLE) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT32) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT64) { - forward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Combine forward"); - } +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + CombineAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + // TODO: to be implemented + float forward_time = 0.5; + float backward_time = 0.5; + float sync_time = 0.0; + return make_metrics(forward_time, backward_time, sync_time, env); } -template -void Combine::forward_task_with_type(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(output_domain == input_domain); - - const DT *input_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - DT *output_ptr = helperGetTensorPointerWO
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - - forward_kernel
(input_ptr, output_ptr, output_domain.get_volume()); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + + init.add_input_slot(INPUT); + + register_task(COMBINE_INIT_TASK_ID, "Combine Init", init, init_task); } -void Combine::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - DataType data_type = *((DataType *)task->args); - if (data_type == DT_FLOAT) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_DOUBLE) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT32) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT64) { - backward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Combine backward"); - } +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + register_task(COMBINE_FWD_TASK_ID, "Combine Fwd", fwd, forward_task); } -template -void Combine::backward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - Domain output_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Domain input_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(output_grad_domain == input_grad_domain); - - const DT *output_grad_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - DT *input_grad_ptr = helperGetTensorPointerRW
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - - backward_kernel
( - output_grad_ptr, input_grad_ptr, output_grad_domain.get_volume()); +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(COMBINE_FWD_TASK_ID)); + + register_task(COMBINE_BWD_TASK_ID, "Combine Bwd", bwd, backward_task); } }; // namespace FlexFlow namespace std { -size_t hash::operator()( - FlexFlow::CombineParams const ¶ms) const { - size_t key = 0; - hash_combine(key, params.combine_legion_dim); - hash_combine(key, params.combine_degree); - return key; -} }; // namespace std diff --git a/lib/runtime/src/ops/combine.h b/lib/runtime/src/ops/combine.h index 455a8ea780..4b96ce6495 100644 --- a/lib/runtime/src/ops/combine.h +++ b/lib/runtime/src/ops/combine.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_COMBINE_H #include "op-attrs/ops/combine.h" -#include "op_task_invocation.h" +#include "task_spec/op_task_invocation.h" #include "sim_environment.h" namespace FlexFlow { @@ -20,7 +20,7 @@ OpTaskInvocation backward(CombineAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, CombineAttrs const &attrs, - ParallelTensorShape const &input_shape, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); @@ -82,3 +82,275 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + +// size_t hash::operator()( +// FlexFlow::CombineParams const ¶ms) const { +// size_t key = 0; +// hash_combine(key, params.combine_legion_dim); +// hash_combine(key, params.combine_degree); +// return key; +// } + +// template +// void Combine::backward_task_with_type( +// Task const *task, +// std::vector const ®ions, +// Context ctx, +// Runtime *runtime) { +// Domain output_grad_domain = runtime->get_index_space_domain( +// ctx, task->regions[0].region.get_index_space()); +// Domain input_grad_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// assert(output_grad_domain == input_grad_domain); + +// const DT *output_grad_ptr = helperGetTensorPointerRO
( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// DT *input_grad_ptr = helperGetTensorPointerRW
( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); + +// backward_kernel
( +// output_grad_ptr, input_grad_ptr, output_grad_domain.get_volume()); +// } + + +// void Combine::backward_task(Task const *task, +// std::vector const ®ions, +// Context ctx, +// Runtime *runtime) { +// assert(regions.size() == 2); +// assert(task->regions.size() == 2); +// DataType data_type = *((DataType *)task->args); +// if (data_type == DT_FLOAT) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_DOUBLE) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT32) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT64) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else { +// assert(false && "Unsupported data type in Combine backward"); +// } +// } + +// bool Combine::get_int_parameter(PMParameter para, int *value) const { +// switch (para) { +// case PM_COMBINE_DIM: +// *value = combine_dim; +// return true; +// case PM_COMBINE_DEGREE: +// *value = combine_degree; +// return true; +// default: +// return Op::get_int_parameter(para, value); +// } +// } + +// bool Combine::append_parallel_op_info( +// std::vector ¶llel_ops) const { +// ParallelOpInfo ret; +// ret.op_type = op_type; +// ret.parallel_dim = combine_dim; +// ret.parallel_degree = combine_degree; +// parallel_ops.push_back(ret); +// return true; +// } + +// tl::optional Combine::as_dot() const { +// RecordFormatter rf; +// { +// std::ostringstream oss; +// oss << "dim(" << this->combine_dim << ")"; +// rf << oss.str(); +// } +// { +// std::ostringstream oss; +// oss << "deg(" << this->combine_degree << ")"; +// rf << oss.str(); +// } +// return rf; +// } + + +// void Combine::init(FFModel const &ff) { +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// assert(numOutputs == 1); +// assert(numInputs == 1); +// IndexLauncher launcher(COMBINE_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(Combine)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement( +// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// } + +// void Combine::create_input_partition(FFModel &ff) { +// assert(outputs[0]->part != LogicalPartition::NO_PART); +// assert(inputs[0]->part != LogicalPartition::NO_PART); +// ff.create_disjoint_partition(outputs[0]->num_dims, +// outputs[0]->dims, +// outputs[0]->parallel_is, +// inputs[0]->region, +// input_lp); +// ff.create_disjoint_partition(inputs[0]->num_dims, +// inputs[0]->dims, +// inputs[0]->parallel_is, +// outputs[0]->region_grad, +// output_grad_lp); +// } + +// void Combine::forward(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// assert(numOutputs == 1); +// assert(numInputs == 1); +// assert(inputs[0]->data_type == outputs[0]->data_type); +// DataType data_type = inputs[0]->data_type; +// IndexLauncher launcher(COMBINE_FWD_TASK_ID, +// outputs[0]->parallel_is, +// TaskArgument(&data_type, sizeof(data_type)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement( +// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +// void Combine::backward(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// assert(numOutputs == 1); +// assert(numInputs == 1); +// assert(inputs[0]->data_type == outputs[0]->data_type); +// DataType data_type = inputs[0]->data_type; +// IndexLauncher launcher(COMBINE_BWD_TASK_ID, +// inputs[0]->parallel_is, +// TaskArgument(&data_type, sizeof(DataType)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// inputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(output_grad_lp, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + + + +// CombineParams Combine::get_params() const { +// CombineParams params; +// params.combine_legion_dim = this->combine_dim; +// params.combine_degree = this->combine_degree; +// return params; +// } + +// Combine::Combine(FFModel &model, +// CombineParams const ¶ms, +// ParallelTensor const input, +// char const *name) +// : Combine(model, +// input, +// params.combine_legion_dim, +// params.combine_degree, +// name) {} + +// Combine::Combine(FFModel &model, +// const ParallelTensor _input, +// int _combine_legion_dim, +// int _combine_degree, +// char const *name) +// : ParallelOp(model, OP_COMBINE, name, _input), +// combine_dim(_combine_legion_dim), combine_degree(_combine_degree) { +// int numdim = _input->num_dims; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdim; i++) { +// dims[i] = _input->dims[i]; +// } +// assert(combine_degree > 0 && "Must use combine_degree > 0"); +// assert(dims[combine_dim].degree % combine_degree == 0); +// dims[combine_dim].degree /= combine_degree; +// ParallelTensorBase::update_parallel_ids(numdim, dims); +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// numdim, dims, DT_FLOAT, this); +// // inputs[0]->print("Combine::input"); +// // outputs[0]->print("Combine::output"); +// } + +// /*static*/ +// void Combine::forward_task(Task const *task, +// std::vector const ®ions, +// Context ctx, +// Runtime *runtime) { +// assert(regions.size() == 2); +// assert(task->regions.size() == 2); +// DataType data_type = *((DataType *)task->args); +// if (data_type == DT_FLOAT) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_DOUBLE) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT32) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT64) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else { +// assert(false && "Unsupported data type in Combine forward"); +// } +// } + +// template +// void Combine::forward_task_with_type(Task const *task, +// std::vector const ®ions, +// Context ctx, +// Runtime *runtime) { +// Domain input_domain = runtime->get_index_space_domain( +// ctx, task->regions[0].region.get_index_space()); +// Domain output_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// assert(output_domain == input_domain); + +// const DT *input_ptr = helperGetTensorPointerRO
( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// DT *output_ptr = helperGetTensorPointerWO
( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); + +// forward_kernel
(input_ptr, output_ptr, output_domain.get_volume()); +// } From f2205f43b18ca97e332c11d56668e1c0516c6826 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:26:59 -0700 Subject: [PATCH 10/26] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 23522a89f7..6850dec178 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -19,8 +19,8 @@ FF_VISITABLE_STRUCT_NO_EQ( namespace Kernels { namespace BatchMatmul { -BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, - Allocator allocator, +BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator const &allocator, int a_seq_length_dim, int b_seq_length_dim); From 2f4662da4fcb64dad9de7dd6294e8e6baa3e532e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:28:45 -0700 Subject: [PATCH 11/26] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 6850dec178..e5062b2c61 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -26,7 +26,7 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, - float *o_ptr, + float *output_ptr, float const *a_ptr, float const *b_ptr, float const *c_ptr, From 642eb90d99bdb83b3a212be50232520b497816ba Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:34:04 -0700 Subject: [PATCH 12/26] change --- lib/kernels/include/kernels/batch_matmul_kernels.h | 5 ++--- lib/runtime/src/ops/batch_matmul.cc | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index e5062b2c61..ec32648d0f 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -27,9 +27,8 @@ BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, void forward_kernel(ffStream_t stream, BMMPerDeviceState const &meta, float *output_ptr, - float const *a_ptr, - float const *b_ptr, - float const *c_ptr, + float const *lhs_input_ptr, + float const *rhs_input_ptr, int m, int n, int k, diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 59570f4417..09eb1f23c7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -134,7 +134,6 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { output.get_float_ptr(), a_input.get_float_ptr(), b_input.get_float_ptr(), - nullptr, // c_ptr m, n, k, From e79406a842117e7823d8d63ed21ad5613c846078 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:46:00 -0700 Subject: [PATCH 13/26] change --- lib/kernels/src/cuda/batch_matmul_kernels.cu | 5 +---- lib/kernels/src/hip/batch_matmul_kernels.cpp | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index 3593ac4ab2..cde0df93c0 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -18,9 +18,6 @@ namespace FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -124,7 +121,7 @@ O = A * B */ void forward_kernel(cudaStream_t stream, - BatchMatmulPerDeviceState const *meta, + BatchMatmulPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp b/lib/kernels/src/hip/batch_matmul_kernels.cpp index e5334b1841..a06442d3d6 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -29,7 +29,7 @@ O: (batch, n, m) O = A * B */ void forward_kernel(hipStream_t stream, - BatchMatmulPerDeviceState const *meta, + BatchMatmulPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, From bb3c10f8a9f46c9a09d514a9c70f4e72e37e3e94 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 10:51:31 -0700 Subject: [PATCH 14/26] change --- lib/op-attrs/include/op-attrs/ops/batch_matmul.h | 3 --- lib/runtime/src/ops/batch_matmul.cc | 6 +----- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 6ac2a2cf69..00c700ba20 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -14,9 +14,6 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); -ParallelTensorShape get_weights_shape(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 09eb1f23c7..669cad215b 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -80,9 +80,7 @@ static DeviceSpecific DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); - - // assert(weight.shape.get_volume() * sizeof(float) == - // acc.unwrap(per_device_state)->weightSize); + return per_device_state; } @@ -230,8 +228,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, ParallelTensorShape output_shape = get_output_shape(attrs, a_input.shape, b_input.shape); - ParallelTensorShape weight_shape = - get_weights_shape(attrs, a_input.shape, b_input.shape); SimTaskBinding init_binding; init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); From 848850900f9d44fe29df90c69a2099cb71be299e Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:00:34 -0700 Subject: [PATCH 15/26] change --- lib/runtime/include/runtime/config.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index 21cfca20cf..ef7e779469 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -107,7 +107,7 @@ struct FFConfig : public use_visitable_cmp { struct FFIterationConfig { FFIterationConfig(); void reset(); - req seq_length; + int seq_length; }; FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); From 36eba295424c5c58ea0e3aedea76e3f3adfd3790 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:03:04 -0700 Subject: [PATCH 16/26] change --- lib/runtime/src/ops/batch_matmul.h | 55 ------------------------------ 1 file changed, 55 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index 42c9bc23de..018fe1d582 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -30,61 +30,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, ProfilingSettings const &settings, MachineView const &pc); -/* class BatchMatmul : public Op { */ -/* public: */ -/* BatchMatmul(FFModel &model, */ -/* const ParallelTensor A, */ -/* const ParallelTensor B, */ -/* int a_seq_length_dim, */ -/* int b_seq_length_dim, */ -/* char const *name = nullptr); */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* /1* static PCG::Node deserialize(FFModel &ff, *1/ */ -/* /1* Legion::Deserializer &d, *1/ */ -/* /1* ParallelTensor inputs[], *1/ */ -/* /1* int num_inputs); *1/ */ -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* OpTaskBinding get_init_task_binding() const override; */ -/* OpTaskBinding get_fwd_task_binding() const override; */ -/* OpTaskBinding get_bwd_task_binding() const override; */ -/* private: */ -/* template */ -/* void init_with_dim(FFModel const &ff); */ -/* template */ -/* void forward_with_dim(FFModel const &ff); */ -/* template */ -/* void backward_with_dim(FFModel const &ff); */ - -/* public: */ -/* int a_seq_length_dim, b_seq_length_dim; */ -/* }; */ - } // namespace FlexFlow #endif From 98773b4237de4a4585711bca14d26b9d3c6e75ff Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:15:53 -0700 Subject: [PATCH 17/26] change --- lib/runtime/src/ops/batch_matmul.cc | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 669cad215b..a8c2ec7bd7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -72,8 +72,8 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { - auto const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); - auto const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); + int const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); + int const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); PerDeviceFFHandle handle = acc.get_argument(HANDLE); Allocator allocator = acc.get_allocator(); @@ -94,9 +94,6 @@ static DeviceSpecific } static optional forward_task_impl(TaskArgumentAccessor const &acc) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - auto a_input = acc.get_tensor(A_INPUT); auto b_input = acc.get_tensor(B_INPUT); auto output = acc.get_tensor(OUTPUT); @@ -148,10 +145,6 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - // Currently assume C is NULL - assert(regions.size() == 6); - assert(task->regions.size() == 6); - // BatchMatmul* bmm = (BatchMatmul*) task->args; FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); @@ -161,15 +154,15 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); // is this equivalent to checking `Domain` equality? - assert(output == output_grad); + assert(output.shape == output_grad.shape); auto a_input = acc.get_tensor(A_INPUT); auto a_input_grad = acc.get_tensor_grad(A_INPUT); - assert(a_input == a_input_grad); + assert(a_input.shape == a_input_grad.shape); auto b_input = acc.get_tensor(B_INPUT); auto b_input_grad = acc.get_tensor_grad(B_INPUT); - assert(b_input == b_input_grad); + assert(b_input.shape == b_input_grad.shape); // check dins int m = b_input.shape[legion_dim_t(0)]; @@ -181,8 +174,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.size() == b_input.shape.size()); assert(a_input.shape.size() == output.shape.size()); int batch = 1; - for (int i = 2; i < a_input.shape.get_dim(); - i++) { //@colin get_dim() or get_volume()? + for (int i = 2; i < a_input.shape.dims.num_dims(); + i++) { int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); From ae592613cf3fa6c763ca0f14b30a1899831ee4a9 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 11:28:18 -0700 Subject: [PATCH 18/26] change --- lib/runtime/src/task_spec/op_arg_ref.h | 10 ++-------- lib/runtime/src/task_spec/runtime_arg_ref.h | 2 ++ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index f2361af05d..20ce6892a2 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -4,15 +4,13 @@ #include "arg_ref.h" #include "device_specific.h" #include "op-attrs/parallel_tensor_shape.h" -#include "runtime/config.h" namespace FlexFlow { enum class OpArgRefType { PER_DEVICE_OP_STATE, - PARALLEL_TENSOR_SHAPE, - ITERATION_CONFIG -}; + PARALLEL_TENSOR_SHAPE + }; template using OpArgRef = ArgRef; @@ -28,10 +26,6 @@ OpArgRef input_parallel_tensor_shape(int idx) { return {OpArgRefType::PARALLEL_TENSOR_SHAPE}; } -OpArgRef iteration_config() { - return {OpArgRefType::ITERATION_CONFIG}; -} - } // namespace FlexFlow #endif diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.h b/lib/runtime/src/task_spec/runtime_arg_ref.h index 6b4345091a..033c2bcfbc 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.h +++ b/lib/runtime/src/task_spec/runtime_arg_ref.h @@ -3,6 +3,7 @@ #include "arg_ref.h" #include "device_specific.h" +#include "runtime/config.h" namespace FlexFlow { @@ -15,6 +16,7 @@ using RuntimeArgRefSpec = ArgRefSpec; RuntimeArgRef profiling_settings(); RuntimeArgRef> ff_handle(); +RuntimeArgRef iteration_config(); } // namespace FlexFlow From 8db251252670a8107430318da166e2ba059879ca Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 17:50:28 -0700 Subject: [PATCH 19/26] fix asserts --- lib/runtime/src/ops/attention.cc | 15 --------------- lib/runtime/src/ops/batch_matmul.cc | 13 ++++--------- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/lib/runtime/src/ops/attention.cc b/lib/runtime/src/ops/attention.cc index bca87bdb53..94e2b03731 100644 --- a/lib/runtime/src/ops/attention.cc +++ b/lib/runtime/src/ops/attention.cc @@ -121,18 +121,6 @@ static DeviceSpecific int num_samples = get_piece_shape(query_parallel_tensor_shape)[ff_dim_t(2)]; int num_heads = get_piece_shape(weight_parallel_tensor_shape)[ff_dim_t(1)]; - assert(qoSeqLength == query.shape[legion_dim_t(1)]); - assert(qSize == query.shape[legion_dim_t(0)]); - assert(num_samples == key.shape[legion_dim_t(2)]); - assert(kvSeqLength == key.shape[legion_dim_t(1)]); - assert(kSize == key.shape[legion_dim_t(0)]); - assert(num_samples == value.shape[legion_dim_t(2)]); - assert(kvSeqLength == value.shape[legion_dim_t(1)]); - assert(vSize == value.shape[legion_dim_t(0)]); - assert(num_samples == output.shape[legion_dim_t(2)]); - assert(qoSeqLength == output.shape[legion_dim_t(1)]); - assert(oProjSize == output.shape[legion_dim_t(0)]); - DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, @@ -149,9 +137,6 @@ static DeviceSpecific qoSeqLength, kvSeqLength, attrs.add_bias_kv)); - - assert(weight.shape.get_volume() * sizeof(float) == - acc.unwrap(per_device_state)->weightSize); return per_device_state; } diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index a8c2ec7bd7..00699652e7 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -110,8 +110,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { int k = a_input.shape[legion_dim_t(0)]; assert(k == b_input.shape[legion_dim_t(1)]); - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; for (int i = 2; i < a_input.shape.get_dim(); @@ -171,8 +171,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(n == output.shape[legion_dim_t(1)]); int k = a_input.shape[legion_dim_t(0)]; assert(k == b_input.shape[legion_dim_t(1)]); - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; for (int i = 2; i < a_input.shape.dims.num_dims(); i++) { @@ -182,11 +182,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { batch *= dim_size; } - // TODO: add support for meta->a_seq_length_dim >= 0 - // or meta->b_seq_length_dim >= 0 - assert((meta->a_seq_length_dim >= a_len) || (iter_config.seq_length == 0)); - assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0)); - return profile(backward_kernel, profiling, "[BatchMatmul] backward_time = %.2lfms\n", From e8a6c30a439fcb1168986939de909d6c717c66f5 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 17:57:44 -0700 Subject: [PATCH 20/26] remove asserts --- lib/runtime/src/ops/batch_norm.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lib/runtime/src/ops/batch_norm.cc b/lib/runtime/src/ops/batch_norm.cc index ffd52c96fb..6ebf359051 100644 --- a/lib/runtime/src/ops/batch_norm.cc +++ b/lib/runtime/src/ops/batch_norm.cc @@ -130,9 +130,6 @@ static DeviceSpecific } static optional forward_task_impl(TaskArgumentAccessor const &acc) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -161,9 +158,6 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - assert(regions.size() == 7); - assert(task->regions.size() == 7); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); From 09acbe545ad4d6d1cab1e14e55bf1d7e9a638954 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 29 Aug 2023 17:59:10 -0700 Subject: [PATCH 21/26] format --- lib/runtime/src/ops/batch_matmul.cc | 5 ++--- lib/runtime/src/task_spec/op_arg_ref.h | 5 +---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 00699652e7..45c5e11b9c 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -80,7 +80,7 @@ static DeviceSpecific DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); - + return per_device_state; } @@ -174,8 +174,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(a_input.shape.get_volume() == b_input.shape.get_volume()); assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; - for (int i = 2; i < a_input.shape.dims.num_dims(); - i++) { + for (int i = 2; i < a_input.shape.dims.num_dims(); i++) { int dim_size = a_input.shape[legion_dim_t(i)]; assert(dim_size == b_input.shape[legion_dim_t(i)]); assert(dim_size == output.shape[legion_dim_t(i)]); diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index 20ce6892a2..3e931d79a4 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -7,10 +7,7 @@ namespace FlexFlow { -enum class OpArgRefType { - PER_DEVICE_OP_STATE, - PARALLEL_TENSOR_SHAPE - }; +enum class OpArgRefType { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE }; template using OpArgRef = ArgRef; From 67d2d1ed9620748ee2a415f7ce64929e613b3b47 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Thu, 31 Aug 2023 12:39:08 -0700 Subject: [PATCH 22/26] concat --- lib/kernels/include/kernels/concat_kernels.h | 15 +- lib/kernels/src/hip/concat_kernels.cpp | 4 - lib/op-attrs/include/op-attrs/ops/concat.h | 4 +- lib/runtime/src/ops/concat.cc | 610 ++++-------------- lib/runtime/src/ops/concat.h | 335 +++++++++- lib/runtime/src/serialization.h | 9 +- .../src/task_spec/op_task_invocation.h | 17 +- lib/runtime/src/task_spec/op_tensor_spec.h | 21 + .../src/task_spec/variadic_tensor_ref.h | 26 + 9 files changed, 540 insertions(+), 501 deletions(-) create mode 100644 lib/runtime/src/task_spec/op_tensor_spec.h create mode 100644 lib/runtime/src/task_spec/variadic_tensor_ref.h diff --git a/lib/kernels/include/kernels/concat_kernels.h b/lib/kernels/include/kernels/concat_kernels.h index 741bbbe9f0..1bbefdf843 100644 --- a/lib/kernels/include/kernels/concat_kernels.h +++ b/lib/kernels/include/kernels/concat_kernels.h @@ -6,28 +6,27 @@ namespace FlexFlow { -class ConcatPerDeviceState : public PerDeviceOpState { -public: - ConcatPerDeviceState(FFHandler handle) : PerDeviceOpState(handle){}; - int legion_axis; - char op_name[MAX_OPNAME]; +struct ConcatPerDeviceState { + req legion_axis; }; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ConcatPerDeviceState, legion_axis); + namespace Kernels { namespace Concat { -void init_meta(ConcatPerDeviceState *meta, int legion_axis); +ConcatPerDeviceState init_kernel(ff_dim_t legion_axis); void forward_kernel(ffStream_t stream, ConcatPerDeviceState const *m, GenericTensorAccessorW const &output, - GenericTensorAccessorR const *inputs, + std::vector const &inputs, int num_inputs); void backward_kernel(ffStream_t stream, ConcatPerDeviceState const *m, GenericTensorAccessorR const &output_grad, - GenericTensorAccessorW const *input_grads, + std::vector const &input_grads, int num_inputs); } // namespace Concat diff --git a/lib/kernels/src/hip/concat_kernels.cpp b/lib/kernels/src/hip/concat_kernels.cpp index e818f8b568..f943bc9156 100644 --- a/lib/kernels/src/hip/concat_kernels.cpp +++ b/lib/kernels/src/hip/concat_kernels.cpp @@ -26,10 +26,6 @@ using Legion::Rect; namespace Kernels { namespace Concat { -void init_meta(ConcatPerDeviceState *m, int legion_axis) { - m->legion_axis = legion_axis; -} - template void calc_blk_size(coord_t &num_blocks, coord_t &blk_size, diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index b9bd14a231..cbc864be44 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -9,9 +9,9 @@ namespace FlexFlow { struct ConcatAttrs { - ff_dim_t axis; + req axis; }; -FF_VISITABLE_STRUCT(ConcatAttrs, axis); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ConcatAttrs, axis); CHECK_VALID_OP_ATTR(ConcatAttrs); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/concat.cc b/lib/runtime/src/ops/concat.cc index f17a33b956..20c54c993e 100644 --- a/lib/runtime/src/ops/concat.cc +++ b/lib/runtime/src/ops/concat.cc @@ -16,537 +16,197 @@ #include "concat.h" #include "kernels/concat_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" +#include "task_spec/variadic_tensor_ref.h" #include "utils/hash-utils.h" +#include "op-attrs/get_output_shapes.h" namespace FlexFlow { -enum Slots { - INPUTS, - OUTPUT, - INPUT_GRADS, - OUTPUT_GRAD, - ATTRS, - PROFILING -} +using namespace FlexFlow::Kernels::Concat; -// declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; -using PCG::Node; -using namespace FlexFlow::Kernels::Concat; +enum Slots { + INPUTS, + OUTPUT, + ATTRS, + PROFILING, + HANDLE, + PER_DEVICE_STATE, + NUM_INPUTS +}; -bool operator==(ConcatParams const &lhs, ConcatParams const &rhs) { - return lhs.axis == rhs.axis; -} +OpTaskInvocation init(ConcatAttrs const &attrs) { + OpTaskBinding binding; -ConcatParams Concat::get_params() const { - ConcatParams params; - params.axis = legion_axis; - return params; -} + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(ATTRS, attrs); -Tensor - FFModel::concat(int n, Tensor const *tensors, int axis, char const *name) { - Layer *concat = new Layer(this, - OP_CONCAT, - DT_FLOAT, - name, - n /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - tensors); - int numdim = tensors[0]->num_dims; - // Making sure axis is between [0, numdim) - axis = (axis % numdim + numdim) % numdim; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = tensors[0]->dims[i]; - } - for (int i = 1; i < n; i++) { - assert(tensors[i]->data_type == tensors[0]->data_type); - assert(tensors[i]->num_dims == tensors[0]->num_dims); - for (int j = 0; j < numdim; j++) { - if (j != numdim - axis - 1) { - assert(tensors[i]->dims[j] == tensors[0]->dims[j]); - } else { - dims[j] += tensors[i]->dims[j]; - } - } - } - concat->outputs[0] = create_tensor_legion_ordering( - numdim, dims, tensors[0]->data_type, concat, 0, true /*create_grad*/); - concat->add_int_property("legion_axis", numdim - axis - 1); - layers.push_back(concat); - return concat->outputs[0]; + return {CONCAT_INIT_TASK_ID, binding}; } -Op *Concat::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("legion_axis", value); - int legion_axis = value; - return new Concat( - model, inputs.size(), inputs.data(), legion_axis, layer->name); -} +OpTaskInvocation forward(ConcatAttrs const &attrs) { + OpTaskBinding binding; + binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind(INPUTS, get_input_tensors()); + binding.bind(OUTPUT, output_tensor(0)); + binding.bind(NUM_INPUTS, get_number_inputs()); + binding.bind_arg(PROFILING, profiling_settings()); -Concat::Concat(FFModel &model, - int _n, - ParallelTensor const *_tensors, - int _legion_axis, - char const *name) - : Op(model, - OP_CONCAT, - DT_FLOAT, - name, - _n /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - _tensors), - legion_axis(_legion_axis) { - int num_dim = inputs[0]->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < num_dim; i++) { - dims[i] = inputs[0]->dims[i]; - } - for (int i = 1; i < numInputs; i++) { - assert(inputs[i]->data_type == inputs[0]->data_type); - assert(inputs[i]->num_dims == inputs[0]->num_dims); - for (int j = 0; j < num_dim; j++) { - if (j != legion_axis) { - assert(inputs[i]->dims[j] == inputs[0]->dims[j]); - } else { - // Assert that the concat dim cannot be parallelized - assert(inputs[i]->dims[j].parallel_idx == -1); - assert(inputs[i]->dims[j].degree == 1); - dims[j].size += inputs[i]->dims[j].size; - } - } - } - numOutputs = 1; - outputs[0] = model.create_parallel_tensor_legion_ordering( - num_dim, dims, inputs[0]->data_type, this); + return {CONCAT_FWD_TASK_ID, binding}; } -Concat::Concat(FFModel &model, - ConcatParams const ¶ms, - std::vector const &inputs, - char const *name) - : Concat(model, inputs.size(), inputs.data(), params.axis, name) {} +OpTaskInvocation backward(ConcatAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); + return {CONCAT_BWD_TASK_ID, b}; +} - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + auto const &attrs = acc.get_argument(ATTRS); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); - init.add_input_slot(INPUTS, SlotType::VARIADIC); - init.add_output_slot(OUTPUT); + DeviceSpecific per_device_state = + acc.create_device_specific(init_kernel(attrs.axis)); + return per_device_state; +} - return init; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + int number_of_inputs = acc.get_argument(NUM_INPUTS); + ProfilingSettings profiling = acc.get_argument(PROFILING); - fwd.add_arg_slot(ATTRS); + auto output = acc.get_tensor(OUTPUT); + auto inputs = acc.get_variadic_tensor(INPUTS); - fwd.add_input_slot(INPUTS, SlotType::VARIADIC); - fwd.add_output_slot(OUTPUT); + return profile(forward_kernel, + profiling, + "[Concat] forward_time = %.2lfms\n", + &per_device_state, + output, + inputs, + number_of_inputs); +} - return init; +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + int number_of_inputs = acc.get_argument(NUM_INPUTS); + ProfilingSettings profiling = acc.get_argument(PROFILING); - bwd.add_arg_slot(ATTRS); + auto input_grads = acc.get_variadic_tensor_grad(INPUTS); + auto output_grad = acc.get_tensor_grad(OUTPUT); - bwd.add_input_grad_slot(INPUT_GRADS, SlotType::VARIADIC); - bwd.add_output_grad_slot(OUTPUT_GRAD); + assert(number_of_inputs <= MAX_NUM_INPUTS); - return bwd; + return profile(backward_kernel, + profiling, + "[Concat] backward_time = %.2lfms\n", + &per_device_state, + output_grad, + input_grads, + number_of_inputs); } -OpTaskBinding Concat::get_init_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(PROFILING, this->profiling); - binding.bind_arg(ATTRS, this->attrs); - - return binding; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -OpTaskBinding Concat::get_fwd_task_binding() const { - OpTaskBinding binding; +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + ConcatAttrs const &attrs, + InputVariadicParallelTensorDesc const &inputs_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + int numInputs = (inputs_shape.shapes).size(); + assert(numInputs <= MAX_NUM_INPUTS); - binding.bind_arg(ATTRS, this->attrs); + auto env = sim.new_environment(); - for (int i = 0; i < this->attrs.n; i++) { - binding.bind(INPUTS, input_tensor(i)); - } + ParallelTensorShape output_shape = get_output_shape(attrs, inputs_shape.shapes); - binding.bind(OUTPUT, output_tensor(0)); + SimTaskBinding init_binding; + init_binding.bind_arg(PROFILING, settings); + init_binding.bind_arg(ATTRS, attrs); - return binding; -} + auto init_accessor = + env.get_init_accessor(CONCAT_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); -OpTaskBinding Concat::get_bwd_task_binding() const { - OpTaskBinding binding; + SimTaskBinding fwd_binding; + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + fwd_binding.bind(INPUTS, inputs_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(NUM_INPUTS, numInputs); + fwd_binding.bind_arg(PROFILING, settings); - binding.bind_arg(ATTRS, this->attrs); + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - for (int i = 0; i < this->attrs.n; i++) { - binding.bind(INPUT_GRADS, input_tensor(i).grad()); - } + auto fwd_accessor = env.get_fwd_accessor(CONCAT_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(CONCAT_BWD_TASK_ID, bwd_binding); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); - return binding; -} + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); -void Concat::init(FFModel const &ff) { - this->execute_task(ff, CONCAT_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(CONCAT_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Concat)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[i]->region)); - // launcher.add_field(i + 1, FID_DATA); - // } - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // inputs[i]->region_grad)); - // launcher.add_field(i + numInputs + 1, FID_DATA); - // } - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); } -PerDeviceOpState *Concat::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - - FFHandler handler = *((FFHandler const *)task->local_args); - ConcatPerDeviceState *m = new ConcatPerDeviceState(handler); - // Note that our internal axis index ordering is opposite to other frameworks - init_meta(m, attrs.legion_axis); - m->profiling = profiling; - std::strcpy(m->op_name, attrs.name); - return m; -} +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); -void Concat::forward(FFModel const &ff) { - this->execute_task(ff, CONCAT_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(CONCAT_FWD_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Concat)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[i]->region)); - // launcher.add_field(i + 1, FID_DATA); - // } - // runtime->execute_index_space(ctx, launcher); + init.add_arg_slot(ATTRS); + init.add_arg_slot(PROFILING); + + register_task(CONCAT_INIT_TASK_ID, "Concat Init", init, init_task); } -/* - regions[0](O): output - regions[1..numInputs](I): inputs -*/ -void Concat::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - // Concat const *cc = (Concat *)task->args; - ConcatPerDeviceState const *m = *((ConcatPerDeviceState **)task->local_args); - // Note that our internal axis index ordering is opposite to other frameworks - assert(regions.size() == attrs.n + 1); - assert(task->regions.size() == attrs.n + 1); - // Domain out_domain = runtime->get_index_space_domain( - // ctx, task->regions[0].region.get_index_space()); - // GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - // DT_FLOAT, regions[0], task->regions[0], FID_DATA, ctx, runtime); - // assert(out_domain.get_dim() == cc->outputs[0].num_dims); - // Domain in_domain[MAX_NUM_INPUTS]; - // for (int i = 0; i < cc->numInputs; i++) - // in_domain[i] = runtime->get_index_space_domain( - // ctx, task->regions[i + 1].region.get_index_space()); - // float *output = helperGetTensorPointerWO( - // regions[0], task->regions[0], FID_DATA, ctx, runtime); - - auto output = acc.get_tensor(OUTPUT); - auto inputs = acc.get_variadic_tensor(INPUTS); - - // GenericTensorAccessorR inputs[MAX_NUM_INPUTS]; - // for (int i = 0; i < attrs.n; i++) { - // // inputs[i] = helperGetTensorPointerRO( - // // regions[i + 1], task->regions[i + 1], FID_DATA, ctx, runtime); - // inputs[i] = helperGetGenericTensorAccessorRO( - // DT_FLOAT, regions[i + 1], task->regions[i + 1], FID_DATA, ctx, - // runtime); - // } - profile(forward_kernel, - m->profiling, - "[Concat] forward_time = %.2lfms\n", - m, - output, - inputs, - attrs.n) -} +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); -void Concat::backward(FFModel const &ff) { - this->execute_task(ff, CONCAT_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(CONCAT_BWD_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Concat)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(0, FID_DATA); - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // inputs[i]->region_grad)); - // // LogicalRegion lr = inputs[i]->region_grad; - // // printf("concat[%d]: region(%d,%d,%d)\n", i+1, - // // lr.get_index_space().get_id(), lr.get_field_space().get_id(), - // // lr.get_tree_id()); - // launcher.add_field(i + 1, FID_DATA); - // } - // runtime->execute_index_space(ctx, launcher); -} + fwd.add_arg_slot(NUM_INPUTS); + fwd.add_arg_slot(PROFILING); + fwd.add_input_slot(INPUTS, SlotType::VARIADIC); + fwd.add_output_slot(OUTPUT); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); -/* - regions[0](I): output_grad - regions[1..numInputs](I/O): input_grad -*/ -void Concat::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - // Concat const *cc = (Concat *)task->args; - ConcatPerDeviceState const *m = *((ConcatPerDeviceState **)task->local_args); - // Note that our internal axis index ordering is opposite to other frameworks - assert(regions.size() == attrs.n + 1); - assert(task->regions.size() == attrs.n + 1); - assert(attrs.n <= MAX_NUM_INPUTS); - // Domain out_grad_domain = runtime->get_index_space_domain( - // ctx, task->regions[0].region.get_index_space()); - // assert(out_grad_domain.get_dim() == cc->outputs[0].num_dims); - // Domain in_grad_domains[MAX_NUM_INPUTS]; - // for (int i = 0; i < cc->numInputs; i++) - // in_grad_domains[i] = runtime->get_index_space_domain( - // ctx, task->regions[i + 1].region.get_index_space()); - // float const *output_grad = helperGetTensorPointerRO( - // regions[0], task->regions[0], FID_DATA, ctx, runtime); - - auto input_grads = acc.get_variadic_tensor(INPUT_GRADS); - auto output_grad = acc.get_tensor(OUTPUT_GRAD); - - // GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( - // DT_FLOAT, regions[0], task->regions[0], FID_DATA, ctx, runtime); - // GenericTensorAccessorW input_grads[MAX_NUM_INPUTS]; - // for (int i = 0; i < attrs.n; i++) { - // // input_grads[i] = helperGetTensorPointerRW( - // // regions[i + 1], task->regions[i + 1], FID_DATA, ctx, runtime); - // input_grads[i] = helperGetGenericTensorAccessorRW( - // DT_FLOAT, regions[i + 1], task->regions[i + 1], FID_DATA, ctx, - // runtime); - // } - - profile(backward_kernel, - m->profiling, - "[Concat] backward_time = %.2lfms\n", - m, - output_grad, - input_grads, - attrs.n) + register_task(CONCAT_FWD_TASK_ID, "Concat Fwd", fwd, forward_task); } -bool Concat::get_int_parameter(PMParameter para, int *value) const { - switch (para) { - case PM_AXIS: - *value = legion_axis; - return true; - default: - return Op::get_int_parameter(para, value); - } -} +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(CONCAT_FWD_TASK_ID)); -bool Concat::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - assert(numInputs <= MAX_NUM_INPUTS); - ParallelTensorBase sub_inputs[MAX_NUM_INPUTS], sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - for (int i = 0; i < numInputs; i++) { - if (!inputs[i]->get_sub_tensor(mv, sub_inputs[i])) { - return false; - } - } - - ConcatPerDeviceState *m = sim->concat_meta; - init_meta(m, this->legion_axis); - - sim->free_all(); - float *input_ptrs[MAX_NUM_INPUTS]; - float *input_grad_ptrs[MAX_NUM_INPUTS]; - bool out_of_memory = false; - for (int i = 0; i < numInputs; i++) { - input_ptrs[i] = - (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); - out_of_memory = out_of_memory || (input_ptrs[i] == NULL); - } - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - Domain out_domain = sub_output.get_domain(); - float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - GenericTensorAccessorW output_acc(DT_FLOAT, out_domain, output_ptr); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - out_of_memory = out_of_memory || (output_ptr == NULL); - if (out_of_memory) { - cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - return true; - } - - Domain in_domains[MAX_NUM_INPUTS]; - GenericTensorAccessorR input_acc[MAX_NUM_INPUTS]; - for (int i = 0; i < numInputs; i++) { - in_domains[i] = sub_inputs[i].get_domain(); - input_acc[i] = - GenericTensorAccessorR(DT_FLOAT, in_domains[i], input_ptrs[i]); - } - - assert(m->profiling == false); - - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, m, output_acc, input_acc, numInputs); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - GenericTensorAccessorW input_grad_accs[MAX_NUM_INPUTS]; - for (int i = 0; i < numInputs; i++) { - input_grad_ptrs[i] = - (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); - out_of_memory = out_of_memory || (input_grad_ptrs[i] == NULL); - input_grad_accs[i] = - GenericTensorAccessorW(DT_FLOAT, in_domains[i], input_grad_ptrs[i]); - } - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - float *output_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - GenericTensorAccessorR output_grad_acc( - DT_FLOAT, out_domain, output_grad_ptr); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - out_of_memory = out_of_memory || (output_grad_ptr == NULL); - if (out_of_memory) { - cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - return true; - } - backward = [&](ffStream_t stream) { - backward_kernel(stream, m, output_grad_acc, input_grad_accs, numInputs); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf( - "[Measure Concat] name(%s) forward_time(%.4lf) backward_time(%.4lf)\n", - name, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure Concat] name(%s) forward_time(%.4lf)\n", - name, - cost_metrics.forward_time); - } - - return true; + register_task(CONCAT_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); } + }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/concat.h b/lib/runtime/src/ops/concat.h index 0e0c0c2523..5f792566b4 100644 --- a/lib/runtime/src/ops/concat.h +++ b/lib/runtime/src/ops/concat.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_CONCAT_H #include "op-attrs/ops/concat.h" -#include "op_task_invocation.h" +#include "task_spec/op_task_invocation.h" #include "sim_environment.h" namespace FlexFlow { @@ -76,3 +76,336 @@ CostMetrics } // namespace FlexFlow #endif + + +// bool operator==(ConcatParams const &lhs, ConcatParams const &rhs) { +// return lhs.axis == rhs.axis; +// } + +// ConcatParams Concat::get_params() const { +// ConcatParams params; +// params.axis = legion_axis; +// return params; +// } + +// Tensor +// FFModel::concat(int n, Tensor const *tensors, int axis, char const *name) { +// Layer *concat = new Layer(this, +// OP_CONCAT, +// DT_FLOAT, +// name, +// n /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// tensors); +// int numdim = tensors[0]->num_dims; +// // Making sure axis is between [0, numdim) +// axis = (axis % numdim + numdim) % numdim; +// int dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdim; i++) { +// dims[i] = tensors[0]->dims[i]; +// } +// for (int i = 1; i < n; i++) { +// assert(tensors[i]->data_type == tensors[0]->data_type); +// assert(tensors[i]->num_dims == tensors[0]->num_dims); +// for (int j = 0; j < numdim; j++) { +// if (j != numdim - axis - 1) { +// assert(tensors[i]->dims[j] == tensors[0]->dims[j]); +// } else { +// dims[j] += tensors[i]->dims[j]; +// } +// } +// } +// concat->outputs[0] = create_tensor_legion_ordering( +// numdim, dims, tensors[0]->data_type, concat, 0, true /*create_grad*/); +// concat->add_int_property("legion_axis", numdim - axis - 1); +// layers.push_back(concat); +// return concat->outputs[0]; +// } + +// Op *Concat::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("legion_axis", value); +// int legion_axis = value; +// return new Concat( +// model, inputs.size(), inputs.data(), legion_axis, layer->name); +// } + +// Concat::Concat(FFModel &model, +// int _n, +// ParallelTensor const *_tensors, +// int _legion_axis, +// char const *name) +// : Op(model, +// OP_CONCAT, +// DT_FLOAT, +// name, +// _n /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// _tensors), +// legion_axis(_legion_axis) { +// int num_dim = inputs[0]->num_dims; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < num_dim; i++) { +// dims[i] = inputs[0]->dims[i]; +// } +// for (int i = 1; i < numInputs; i++) { +// assert(inputs[i]->data_type == inputs[0]->data_type); +// assert(inputs[i]->num_dims == inputs[0]->num_dims); +// for (int j = 0; j < num_dim; j++) { +// if (j != legion_axis) { +// assert(inputs[i]->dims[j] == inputs[0]->dims[j]); +// } else { +// // Assert that the concat dim cannot be parallelized +// assert(inputs[i]->dims[j].parallel_idx == -1); +// assert(inputs[i]->dims[j].degree == 1); +// dims[j].size += inputs[i]->dims[j].size; +// } +// } +// } +// numOutputs = 1; +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// num_dim, dims, inputs[0]->data_type, this); +// } + +// Concat::Concat(FFModel &model, +// ConcatParams const ¶ms, +// std::vector const &inputs, +// char const *name) +// : Concat(model, inputs.size(), inputs.data(), params.axis, name) {} + + +// void Concat::init(FFModel const &ff) { +// this->execute_task(ff, CONCAT_INIT_TASK_ID, get_init_task_signature()); +// // assert(check_output_input_weight_same_parallel_is()); +// // parallel_is = outputs[0]->parallel_is; +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_init(ff, argmap); +// // IndexLauncher launcher(CONCAT_INIT_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(Concat)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region_grad)); +// // launcher.add_field(i + numInputs + 1, FID_DATA); +// // } +// // FutureMap fm = runtime->execute_index_space(ctx, launcher); +// // fm.wait_all_results(); +// // set_opmeta_from_futuremap(ff, fm); +// } + + +// void Concat::forward(FFModel const &ff) { +// this->execute_task(ff, CONCAT_FWD_TASK_ID, get_fwd_task_signature()); +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_forward(ff, argmap); +// // IndexLauncher launcher(CONCAT_FWD_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(Concat)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](O): output + regions[1..numInputs](I): inputs +*/ + + +// void Concat::backward(FFModel const &ff) { +// this->execute_task(ff, CONCAT_BWD_TASK_ID, get_bwd_task_signature()); +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_backward(ff, argmap); +// // IndexLauncher launcher(CONCAT_BWD_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(Concat)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region_grad)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, +// // 0 /*projection id*/, +// // READ_WRITE, +// // EXCLUSIVE, +// // inputs[i]->region_grad)); +// // // LogicalRegion lr = inputs[i]->region_grad; +// // // printf("concat[%d]: region(%d,%d,%d)\n", i+1, +// // // lr.get_index_space().get_id(), lr.get_field_space().get_id(), +// // // lr.get_tree_id()); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): output_grad + regions[1..numInputs](I/O): input_grad +*/ + +// bool Concat::get_int_parameter(PMParameter para, int *value) const { +// switch (para) { +// case PM_AXIS: +// *value = legion_axis; +// return true; +// default: +// return Op::get_int_parameter(para, value); +// } +// } + + +// bool Concat::measure_operator_cost(Simulator *sim, +// MachineView const &mv, +// CostMetrics &cost_metrics) const { +// assert(numInputs <= MAX_NUM_INPUTS); +// ParallelTensorBase sub_inputs[MAX_NUM_INPUTS], sub_output; +// if (!outputs[0]->get_sub_tensor(mv, sub_output)) { +// return false; +// } +// for (int i = 0; i < numInputs; i++) { +// if (!inputs[i]->get_sub_tensor(mv, sub_inputs[i])) { +// return false; +// } +// } + +// ConcatPerDeviceState *m = sim->concat_meta; +// init_meta(m, this->legion_axis); + +// sim->free_all(); +// float *input_ptrs[MAX_NUM_INPUTS]; +// float *input_grad_ptrs[MAX_NUM_INPUTS]; +// bool out_of_memory = false; +// for (int i = 0; i < numInputs; i++) { +// input_ptrs[i] = +// (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); +// out_of_memory = out_of_memory || (input_ptrs[i] == NULL); +// } +// cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + +// Domain out_domain = sub_output.get_domain(); +// float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); +// GenericTensorAccessorW output_acc(DT_FLOAT, out_domain, output_ptr); +// cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + +// out_of_memory = out_of_memory || (output_ptr == NULL); +// if (out_of_memory) { +// cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// return true; +// } + +// Domain in_domains[MAX_NUM_INPUTS]; +// GenericTensorAccessorR input_acc[MAX_NUM_INPUTS]; +// for (int i = 0; i < numInputs; i++) { +// in_domains[i] = sub_inputs[i].get_domain(); +// input_acc[i] = +// GenericTensorAccessorR(DT_FLOAT, in_domains[i], input_ptrs[i]); +// } + +// assert(m->profiling == false); + +// std::function forward, backward; +// forward = [&](ffStream_t stream) { +// forward_kernel(stream, m, output_acc, input_acc, numInputs); +// }; +// if (sim->computationMode == COMP_MODE_TRAINING) { +// GenericTensorAccessorW input_grad_accs[MAX_NUM_INPUTS]; +// for (int i = 0; i < numInputs; i++) { +// input_grad_ptrs[i] = +// (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); +// out_of_memory = out_of_memory || (input_grad_ptrs[i] == NULL); +// input_grad_accs[i] = +// GenericTensorAccessorW(DT_FLOAT, in_domains[i], input_grad_ptrs[i]); +// } +// cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); +// float *output_grad_ptr = +// (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); +// GenericTensorAccessorR output_grad_acc( +// DT_FLOAT, out_domain, output_grad_ptr); +// cost_metrics.outputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); + +// out_of_memory = out_of_memory || (output_grad_ptr == NULL); +// if (out_of_memory) { +// cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// return true; +// } +// backward = [&](ffStream_t stream) { +// backward_kernel(stream, m, output_grad_acc, input_grad_accs, numInputs); +// }; +// } + +// inner_measure_operator_cost(sim, forward, backward, cost_metrics); + +// if (sim->computationMode == COMP_MODE_TRAINING) { +// printf( +// "[Measure Concat] name(%s) forward_time(%.4lf) backward_time(%.4lf)\n", +// name, +// cost_metrics.forward_time, +// cost_metrics.backward_time); +// } else { +// printf("[Measure Concat] name(%s) forward_time(%.4lf)\n", +// name, +// cost_metrics.forward_time); +// } + +// return true; +// } \ No newline at end of file diff --git a/lib/runtime/src/serialization.h b/lib/runtime/src/serialization.h index adf838201a..6839bfdde5 100644 --- a/lib/runtime/src/serialization.h +++ b/lib/runtime/src/serialization.h @@ -9,7 +9,8 @@ #include "utils/optional.h" #include "utils/variant.h" #include "utils/visitable.h" -#include +#include "utils/type_traits.h" +#include "utils/required.h" namespace FlexFlow { @@ -77,6 +78,12 @@ struct is_trivially_serializable< typename std::enable_if::value>::type> : std::true_type {}; +template +struct is_trivially_serializable>> : is_trivially_serializable> {}; + +template +struct is_trivially_serializable> : is_trivially_serializable {}; + template <> struct is_trivially_serializable : std::true_type {}; template <> diff --git a/lib/runtime/src/task_spec/op_task_invocation.h b/lib/runtime/src/task_spec/op_task_invocation.h index 07f5bf12ae..9203f2f4f1 100644 --- a/lib/runtime/src/task_spec/op_task_invocation.h +++ b/lib/runtime/src/task_spec/op_task_invocation.h @@ -5,6 +5,8 @@ #include "index_task_invocation.h" #include "legion.h" #include "op_arg_ref.h" +#include "op_tensor_spec.h" +#include "variadic_tensor_ref.h" #include "op_task_signature.h" #include "runtime/config.h" #include "runtime/profiling.h" @@ -22,16 +24,6 @@ namespace FlexFlow { enum class IsTrainable { YES, NO }; -struct OpTensorSpec { - TensorRole role; - req idx; -}; -FF_VISITABLE_STRUCT(OpTensorSpec, role, idx); - -OpTensorSpec input_tensor(int); -OpTensorSpec output_tensor(int); -OpTensorSpec weight_tensor(int); - using OpArgSpec = variant + void bind(slot_id name, VariadicTensorRef const &t) { + NOT_IMPLEMENTED(); + } + template void bind_device_specific_arg(slot_id name, T const &t) { NOT_IMPLEMENTED(); diff --git a/lib/runtime/src/task_spec/op_tensor_spec.h b/lib/runtime/src/task_spec/op_tensor_spec.h new file mode 100644 index 0000000000..84261141d3 --- /dev/null +++ b/lib/runtime/src/task_spec/op_tensor_spec.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_TENSOR_SPEC_REF_H +#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_TENSOR_SPEC_REF_H + +#include "op_task_signature.h" + +namespace FlexFlow { + +struct OpTensorSpec { + TensorRole role; + req idx; +}; +FF_VISITABLE_STRUCT(OpTensorSpec, role, idx); + +OpTensorSpec input_tensor(int); +OpTensorSpec output_tensor(int); +OpTensorSpec weight_tensor(int); + + +} // namespace FlexFlow + +#endif diff --git a/lib/runtime/src/task_spec/variadic_tensor_ref.h b/lib/runtime/src/task_spec/variadic_tensor_ref.h new file mode 100644 index 0000000000..1b9bc33b4f --- /dev/null +++ b/lib/runtime/src/task_spec/variadic_tensor_ref.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H +#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H + +#include "op_tensor_spec.h" +#include "arg_ref.h" + +namespace FlexFlow { + +enum class VariadicTensorRefType {INPUT_TENSORS, + NUM_INPUTS}; + +template +using VariadicTensorRef = ArgRef; + +VariadicTensorRef get_input_tensors() { + return {VariadicTensorRefType::INPUT_TENSORS}; +} + +VariadicTensorRef get_number_inputs() { + return {VariadicTensorRefType::NUM_INPUTS}; +} + + +} // namespace FlexFlow + +#endif From d3342d5abe3b12cbd1e49e15cd5cc538b8f1d20b Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 1 Sep 2023 10:56:18 -0700 Subject: [PATCH 23/26] conv2d --- lib/kernels/include/kernels/conv_2d_kernels.h | 65 +- lib/runtime/src/ops/conv_2d.cc | 1154 +++-------------- lib/runtime/src/ops/conv_2d.h | 701 +++++++++- 3 files changed, 905 insertions(+), 1015 deletions(-) diff --git a/lib/kernels/include/kernels/conv_2d_kernels.h b/lib/kernels/include/kernels/conv_2d_kernels.h index 50b3c0601f..dc6bf941b1 100644 --- a/lib/kernels/include/kernels/conv_2d_kernels.h +++ b/lib/kernels/include/kernels/conv_2d_kernels.h @@ -5,45 +5,50 @@ namespace FlexFlow { -class Conv2DPerDeviceState : public PerDeviceOpState { -public: - Conv2DPerDeviceState(FFHandler handler); - ffTensorDescriptor_t inputTensor, biasTensor, outputTensor; +struct Conv2DPerDeviceState { + PerDeviceFFHandle handle; + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t biasTensor; + ffTensorDescriptor_t outputTensor; ffFilterDescriptor_t filterDesc; ffActivationDescriptor_t actiDesc; ffConvolutionDescriptor_t convDesc; ffConvolutionFwdAlgo_t fwdAlgo; ffConvolutionBwdFilterAlgo_t bwdFilterAlgo; ffConvolutionBwdDataAlgo_t bwdDataAlgo; - bool relu, use_bias; - char op_name[MAX_OPNAME]; + req> activation; + req use_bias; }; +FF_VISITABLE_STRUCT_NO_EQ(Conv2DPerDeviceState, + handle, + inputTensor, + biasTensor, + outputTensor, + filterDesc, + actiDesc, + convDesc, + fwdAlgo, + bwdFilterAlgo, + bwdDataAlgo, + activation, + use_bias); + namespace Kernels { namespace Conv2D { -void init_kernel(Conv2DPerDeviceState *m, - int input_w, - int input_h, - int input_c, - int input_n, - int output_w, - int output_h, - int output_c, - int output_n, - int kernel_h, - int kernel_w, - int groups, - int stride_h, - int stride_w, - int pad_h, - int pad_w, - float const *input_ptr, - float *output_ptr, - float const *kernel_ptr, - float *kernel_grad_ptr, - float *forward_time = nullptr, - float *backward_time = nullptr); +Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle, + ffTensorDescriptor_t inputTensor, + ffTensorDescriptor_t biasTensor, + ffTensorDescriptor_t outputTensor, + ffFilterDescriptor_t filterDesc, + ffActivationDescriptor_t actiDesc, + ffConvolutionDescriptor_t convDesc, + ffConvolutionFwdAlgo_t fwdAlgo, + ffConvolutionBwdFilterAlgo_t bwdFilterAlgo, + ffConvolutionBwdDataAlgo_t bwdDataAlgo, + req> relu, + bool use_bias); void forward_kernel(ffStream_t stream, Conv2DPerDeviceState const *m, @@ -58,8 +63,8 @@ void backward_kernel(ffStream_t stream, float *input_grad_ptr, float const *output_ptr, float *output_grad_ptr, - float const *kernel_ptr, - float *kernel_grad_ptr, + float const *filter_ptr, + float *filter_grad_ptr, float *bias_grad_ptr); } // namespace Conv2D diff --git a/lib/runtime/src/ops/conv_2d.cc b/lib/runtime/src/ops/conv_2d.cc index e362c73f92..fc87ed987f 100644 --- a/lib/runtime/src/ops/conv_2d.cc +++ b/lib/runtime/src/ops/conv_2d.cc @@ -1,1051 +1,237 @@ #include "conv_2d.h" #include "kernels/conv_2d_kernels.h" -#include "layer.h" #include "legion/legion_utilities.h" #include "mpark/variant.hpp" -#include "task_spec.h" #include "utils/hash-utils.h" +#include "op-attrs/get_output_shapes.h" namespace FlexFlow { -enum Slots { - INPUT, - OUTPUT, - FILTER, - BIAS, - FILTER_GRAD, - INPUT_GRAD, - OUTPUT_GRAD, - BIAS_GRAD, - ATTRS, - PROFILING, -} - -// declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::InlineLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Conv2D; -Tensor FFModel::conv2d(Tensor const &input, - int outChannels, - int kernelH, - int kernelW, - int strideH, - int strideW, - int paddingH, - int paddingW, - ActiMode activation, - int groups, - bool use_bias, - Layer const *shared_op, - Initializer *kernel_initializer, - Initializer *bias_initializer, - char const *name) { - assert(input->num_dims() == 4); /*NCHW*/ - - Conv2DAttrs attrs = {outChannels, - kernelH, - kernelW, - strideH, - strideW, - paddingH, - paddingW, - groups, - activation, - use_bias}; - - TensorShape output_shape = get_output_shape(attrs, input->get_shape()); - Tensor output = this->tensor_mgr.create(output_shape, CreateGrad::YES, conv); - - std::vector weights; - - TensorShape kernel_shape = get_kernel_shape(attrs, input->get_shape()); - weights.push_back(this->tensor_mgr.create( - kernel_shape, CreateGrad::YES, kernel_initializer, CHOSEN_SYNC_TYPE)); - - if (use_bias) { - TensorShape bias_shape = get_bias_shape(attrs, input->get_shape()); - weights.push_back(this->tensor_mgr.create( - bias_shape, CreateGrad::YES, bias_initializer, CHOSEN_SYNC_TYPE)); - } - - Layer *conv = - this->layer_mgr.create(attrs, DT_FLOAT, name, {input}, weights, {output}); - - //{ - // int numdims = 4; - // int dims[MAX_TENSOR_DIM]; - // dims[3] = input->dims[3]; - // dims[2] = outChannels; - // dims[1] = 1 + (input->dims[1] + 2 * paddingH - kernelH) / strideH; - // dims[0] = 1 + (input->dims[0] + 2 * paddingW - kernelW) / strideW; - // conv->outputs[0] = create_tensor_legion_ordering( - // numdims, dims, DT_FLOAT, conv, 0, true /*create_grad*/); - //} - //{ - // int dims[4] = {kernelW, kernelH, input->dims[2], outChannels}; - // conv->weights[0] = create_weight_legion_ordering(4, - // dims, - // DT_FLOAT, - // conv, - // true /*create_grad*/, - // kernel_initializer, - // CHOSEN_SYNC_TYPE); - //} - // if (use_bias) { - // int dims[1] = {outChannels}; - // conv->weights[1] = create_weight_legion_ordering(1, - // dims, - // DT_FLOAT, - // conv, - // true /*create_grad*/, - // bias_initializer, - // CHOSEN_SYNC_TYPE); - //} - conv->add_initializer("kernel", kernel_initializer); - conv->add_initializer("bias", bias_initializer); - /* layers.push_back(conv); */ - return conv->outputs[0]; -} - -Op *Conv2D::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - return new Conv2D(model, - get(layer->attrs), - inputs, - layer->name, - false /*allocate_weights*/ - ); -} - -/* void Conv2DParams::mark_replica_dims( */ -/* ParallelTensorShape const &input, */ -/* ParallelDim output_dims[MAX_TENSOR_DIM], */ -/* ParallelDim kernel_dims[MAX_TENSOR_DIM], */ -/* ParallelDim bias_dims[MAX_TENSOR_DIM]) const { */ -/* if (output_dims != nullptr) { */ -/* output_dims[Conv2DOutput::REPLICA].is_replica_dim = true; */ -/* } */ -/* if (kernel_dims != nullptr) { */ -/* kernel_dims[Conv2DOutput::REPLICA].is_replica_dim = true; */ -/* } */ -/* if (bias_dims != nullptr) { */ -/* bias_dims[Conv2DBias::REPLICA_1].is_replica_dim = true; */ -/* bias_dims[Conv2DBias::REPLICA_2].is_replica_dim = true; */ -/* bias_dims[Conv2DBias::REPLICA_3].is_replica_dim = true; */ -/* bias_dims[Conv2DBias::REPLICA_4].is_replica_dim = true; */ -/* } */ -/* } */ - -/* int Conv2DParams::output_size(ParallelTensorShape const &input, */ -/* ParallelDim output_dims[MAX_TENSOR_DIM]) const - * { */ -/* int input_w = input.dims[Conv2DInput::WIDTH].size; */ -/* int input_h = input.dims[Conv2DInput::HEIGHT].size; */ - -/* output_dims[Conv2DOutput::SAMPLE].size = - * input.dims[Conv2DInput::SAMPLE].size; */ -/* output_dims[Conv2DOutput::CHANNEL].size = out_channels; */ -/* output_dims[Conv2DOutput::HEIGHT].size = */ -/* 1 + (input_h + 2 * padding_h - kernel_h) / stride_h; */ -/* output_dims[Conv2DOutput::WIDTH].size = */ -/* 1 + (input_w + 2 * padding_w - kernel_w) / stride_w; */ - -/* return input.num_dims; */ -/* }; */ - -/* int Conv2DParams::kernel_size(ParallelTensorShape const &input, */ -/* ParallelDim kernel_dims[MAX_TENSOR_DIM]) const - * { */ -/* kernel_dims[Conv2DKernel::CHANNEL_OUT].size = this->out_channels; */ -/* kernel_dims[Conv2DKernel::CHANNEL_IN].size = */ -/* input.dims[Conv2DInput::CHANNEL].size / this->groups; */ -/* kernel_dims[Conv2DKernel::HEIGHT].size = */ -/* this->kernel_h * input.dims[Conv2DInput::HEIGHT].degree; */ -/* kernel_dims[Conv2DKernel::WIDTH].size = */ -/* this->kernel_w * input.dims[Conv2DInput::WIDTH].degree; */ - -/* return Conv2DKernel::NUMDIM; */ -/* } */ - -/* int Conv2DParams::bias_size(ParallelTensorShape const &input, */ -/* ParallelDim bias_dims[MAX_TENSOR_DIM]) const { */ -/* bias_dims[Conv2DBias::CHANNEL].size = this->out_channels; */ - -/* return Conv2DBias::NUMDIM; */ -/* }; */ - -/* void Conv2DParams::solve_dims(ParallelTensorShape const &input, */ -/* ParallelDim output_dims[MAX_TENSOR_DIM], */ -/* int *output_ndims, */ -/* ParallelDim kernel_dims[MAX_TENSOR_DIM], */ -/* int *kernel_ndims, */ -/* ParallelDim bias_dims[MAX_TENSOR_DIM], */ -/* int *bias_ndims) const { */ -/* assert((output_dims == nullptr) == (output_ndims == nullptr)); */ -/* assert((kernel_dims == nullptr) == (kernel_ndims == nullptr)); */ -/* assert((bias_dims == nullptr) == (bias_ndims == nullptr)); */ - -/* std::vector mapping; */ -/* Conv2D::construct_mappings(mapping, this->use_bias); */ - -/* this->mark_replica_dims(input, output_dims, kernel_dims, bias_dims); */ - -/* std::vector output_dim_sets; */ -/* if (output_dims != nullptr) { */ -/* output_dim_sets.push_back(output_dims); */ -/* } */ - -/* std::vector weight_dim_sets; */ -/* if (kernel_dims != nullptr) { */ -/* weight_dim_sets.push_back(kernel_dims); */ -/* } */ -/* if (bias_dims != nullptr && this->use_bias) { */ -/* weight_dim_sets.push_back(bias_dims); */ -/* } */ - -/* solve_parallel_dim_mappings( */ -/* mapping, {input.dims}, weight_dim_sets, output_dim_sets); */ - -/* if (output_dims != nullptr) { */ -/* *output_ndims = this->output_size(input, output_dims); */ -/* } */ -/* if (kernel_dims != nullptr) { */ -/* *kernel_ndims = this->kernel_size(input, kernel_dims); */ -/* } */ -/* if (bias_dims != nullptr && this->use_bias) { */ -/* *bias_ndims = this->bias_size(input, bias_dims); */ -/* } */ -/* } */ - -/*static*/ -/* void Conv2D::construct_mappings(std::vector &out, - */ -/* bool use_bias) { */ -/* Conv2D::construct_output_mappings(out); */ -/* Conv2D::construct_weight_mappings(out, use_bias); */ -/* } */ - -/*static*/ -/* void Conv2D::construct_output_mappings( */ -/* std::vector &out) { */ -/* Op::construct_output_parallel_dims( */ -/* out, */ -/* {{Conv2DInput::CHANNEL, */ -/* MappingOperation::REPLICATE, */ -/* Conv2DOutput::REPLICA}, */ -/* {Conv2DInput::SAMPLE, MappingOperation::PARTITION, - * Conv2DOutput::SAMPLE}, */ -/* {Conv2DInput::REPLICA, */ -/* MappingOperation::PARTITION, */ -/* Conv2DOutput::CHANNEL}, */ -/* {Conv2DInput::HEIGHT, MappingOperation::PARTITION, - * Conv2DOutput::HEIGHT}, */ -/* {Conv2DInput::WIDTH, MappingOperation::PARTITION, - * Conv2DOutput::WIDTH}}); */ -/* } */ - -/*static*/ -/* void Conv2D::construct_weight_mappings( */ -/* std::vector &out, bool use_bias) { */ -/* Op::construct_weight_parallel_dims( */ -/* out, */ -/* { */ -/* {Conv2DInput::REPLICA, */ -/* MappingOperation::PARTITION, */ -/* Conv2DKernel::CHANNEL_OUT}, */ -/* {Conv2DInput::SAMPLE, */ -/* MappingOperation::REPLICATE, */ -/* Conv2DKernel::REPLICA}, */ -/* {Conv2DInput::CHANNEL, */ -/* MappingOperation::PARTITION, */ -/* Conv2DKernel::CHANNEL_IN}, */ -/* {Conv2DInput::HEIGHT, */ -/* MappingOperation::REPLICATE, */ -/* Conv2DKernel::HEIGHT}, // Kernel::{HEIGHT, WEIGHT} would both work - */ -/* // here */ -/* {Conv2DInput::WIDTH, */ -/* MappingOperation::REPLICATE, */ -/* Conv2DKernel::WIDTH}, // same as above */ -/* }, */ -/* Conv2DInput::INDEX, */ -/* Conv2DKernel::INDEX); */ - -/* if (use_bias) { */ -/* Op::construct_weight_parallel_dims( */ -/* out, */ -/* {{Conv2DInput::REPLICA, Conv2DBias::REPLICA_1}, */ -/* {Conv2DInput::SAMPLE, Conv2DBias::REPLICA_2}, */ -/* {Conv2DInput::CHANNEL, Conv2DBias::CHANNEL}, */ -/* {Conv2DInput::HEIGHT, Conv2DBias::REPLICA_3}, */ -/* {Conv2DInput::WIDTH, Conv2DBias::REPLICA_4}}, */ -/* Conv2DInput::INDEX, */ -/* Conv2DBias::INDEX); */ -/* } */ -/* } */ - -Conv2D::Conv2D(FFModel &model, - Conv2D const &other, - const ParallelTensor input, - bool allocate_weights) - : Conv2D(model, - other.layer_guid, - input, - other.out_channels, - other.kernel_h, - other.kernel_w, - other.stride_h, - other.stride_w, - other.padding_h, - other.padding_w, - other.activation, - other.groups, - other.use_bias, - allocate_weights, - other.name) {} - -Conv2D::Conv2D(FFModel &model, - Conv2DAttrs const &attrs, - std::vector const &inputs, - char const *name, - bool allocate_weights) - : Conv2D(model, - params.layer_guid, - input, - params.out_channels, - params.kernel_h, - params.kernel_w, - params.stride_h, - params.stride_w, - params.padding_h, - params.padding_w, - params.activation, - params.groups, - params.use_bias, - allocate_weights, - name) {} - -/* bool Conv2DParams::is_valid(ParallelTensorShape const &input) const { */ -/* ParallelTensorShape output_shape, kernel_shape, bias_shape; */ -/* this->solve_dims(input, */ -/* output_shape.dims, */ -/* &output_shape.num_dims, */ -/* kernel_shape.dims, */ -/* &kernel_shape.num_dims, */ -/* bias_shape.dims, */ -/* &bias_shape.num_dims); */ -/* bool is_valid = true; */ -/* is_valid &= input.is_valid(); */ -/* is_valid &= output_shape.is_valid(); */ -/* is_valid &= kernel_shape.is_valid(); */ -/* if (use_bias) { */ -/* is_valid &= bias_shape.is_valid(); */ -/* } */ - -/* // TODO FIXME: Currently disable parallelizing the height and width - * dimension */ -/* if (input.dims[0].degree > 1 || input.dims[1].degree > 1) { */ -/* return false; */ -/* } */ - -/* return is_valid; */ -/* } */ - -Conv2D::Conv2D(FFModel &model, - LayerID const &_layer_guid, - const ParallelTensor input, - int outChannels, - int kernelH, - int kernelW, - int strideH, - int strideW, - int paddingH, - int paddingW, - ActiMode activation, - int groups, - bool use_bias, - bool allocate_weights, - char const *name) - : Op(model, - OP_CONV2D, - DT_FLOAT, - name, - 1 /*inputs*/, - use_bias ? 2 : 1 /*weights*/, - allocate_weights, - 1 /*outputs*/, - input), - in_channels(input->dims[Conv2DInput::CHANNEL].size / - input->dims[Conv2DInput::CHANNEL].degree), - out_channels(outChannels), kernel_h(kernelH), kernel_w(kernelW), - stride_h(strideH), stride_w(strideW), padding_h(paddingH), - padding_w(paddingW), activation(activation), groups(groups), - use_bias(use_bias) { - // overwrite layer_guid - layer_guid = _layer_guid; - assert(input->num_dims == Conv2DInput::NUMDIM); - assert(this->stride_h > 0); - assert(this->stride_w > 0); - - ParallelDim output_dims[MAX_TENSOR_DIM], kernel_dims[MAX_TENSOR_DIM], - bias_dims[MAX_TENSOR_DIM]; - int output_ndims, kernel_ndims, bias_ndims; - - this->construct_mappings(*this->parallel_dims_mapping, this->use_bias); - this->get_params().solve_dims(this->inputs[0]->get_shape(), - output_dims, - &output_ndims, - kernel_dims, - &kernel_ndims, - bias_dims, - &bias_ndims); - - if (allocate_weights) { - Initializer *kernel_initializer = new GlorotUniform(std::rand() /*seed*/); - - weights[Conv2DKernel::INDEX] = - model.create_parallel_weight_legion_ordering(kernel_ndims, - kernel_dims, - DT_FLOAT, - NULL /*owner_op*/, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - - if (use_bias) { - Initializer *bias_initializer = new ZeroInitializer(); - - weights[Conv2DBias::INDEX] = - model.create_parallel_weight_legion_ordering(bias_ndims, - bias_dims, - DT_FLOAT, - NULL /*owner_op*/, - true /*create_grad*/, - bias_initializer, - CHOSEN_SYNC_TYPE); - } - } - - outputs[0] = model.create_parallel_tensor_legion_ordering( - output_ndims, output_dims, DT_FLOAT, this); - - assert(check_output_input_weight_parallel_dims(allocate_weights)); -} - -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - - init.add_input_slot(INPUT); - init.add_output_slot(OUTPUT, WRITE_ONLY); - init.add_param_slot(FILTER); - init.add_param_slot(BIAS); - init.add_param_grad_slot(FILTER_GRAD, WRITE_ONLY); - init.add_input_grad_slot(INPUT_GRAD); - - return init; -} - -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT, WRITE_ONLY); - fwd.add_param_slot(FILTER); - fwd.add_param_slot(BIAS); - - return fwd; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_arg_slot(ATTRS); - - bwd.add_input_slot(INPUT); - bwd.add_input_grad_slot(INPUT_GRAD, READ_WRITE); - bwd.add_output_slot(OUTPUT); - bwd.add_output_grad_slot(OUTPUT_GRAD, READ_WRITE); - bwd.add_param_slot(FILTER); - bwd.add_param_grad_slot(FILTER_GRAD, READ_WRITE); - bwd.add_param_grad_slot(BIAS_GRAD, READ_WRITE); - - return bwd; -} +enum Slots { + INPUT, + OUTPUT, + FILTER, + BIAS, + ATTRS, + PROFILING, + PER_DEVICE_STATE, + HANDLE +}; -OpTaskBinding Conv2d::get_init_task_binding() const { +OpTaskInvocation init(Conv2DAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); - - binding.bind(INPUT, input_tensor(0)); - binding.bind(OUTPUT, output_tensor(0)); - binding.bind(FILTER, param_tensor(0)); - binding.bind(BIAS, param_tensor(1)); - binding.bind(FILTER_GRAD, param_tensor(0).grad()); - binding.bind(INPUT_GRAD, input_tensor(0).grad()); + binding.bind_arg(ATTRS, attrs); + binding.bind_arg(HANDLE, ff_handle()); - return binding; + return {CONV2D_INIT_TASK_ID, binding}; } -OpTaskBinding Conv2d::get_fwd_task_binding() const { +OpTaskInvocation forward(Conv2DAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - binding.bind(FILTER, param_tensor(0)); - binding.bind(BIAS, param_tensor(1)); - - return binding; + binding.bind(FILTER, weight_tensor(0)); + binding.bind(BIAS, weight_tensor(1)); + + return {CONV2D_FWD_TASK_ID, binding}; +} + +OpTaskInvocation backward(Conv2DAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); + + return {CONV2D_BWD_TASK_ID, binding}; +} + +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + auto const &attrs = acc.get_argument(ATTRS); + + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t biasTensor; + ffTensorDescriptor_t outputTensor; + ffFilterDescriptor_t filterDesc; + ffActivationDescriptor_t actiDesc; + ffConvolutionDescriptor_t convDesc; + ffConvolutionFwdAlgo_t fwdAlgo; + ffConvolutionBwdFilterAlgo_t bwdFilterAlgo; + ffConvolutionBwdDataAlgo_t bwdDataAlgo; + + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, + inputTensor, + biasTensor, + outputTensor, + filterDesc, + actiDesc, + convDesc, + fwdAlgo, + bwdFilterAlgo, + bwdDataAlgo, + attrs.activation, + attrs.use_bias)); + + return per_device_state; +} + +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -OpTaskBinding Conv2d::get_bwd_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); - - binding.bind(INPUT, input_tensor(0)); - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(OUTPUT, output_tensor(0)); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); - binding.bind(FILTER, param_tensor(0)); - binding.bind(FILTER_GRAD, param_tensor(0).grad()); - binding.bind(BIAS_GRAD, param_tensor(1).grad()); +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); - return binding; -} + auto input = acc.get_tensor(INPUT); + auto filter = acc.get_tensor(FILTER); + auto bias = acc.get_tensor(BIAS); + auto output = acc.get_tensor(OUTPUT); -void Conv2D::init(FFModel const &ff) { - this->execute_task(ff, CONV2D_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(CONV2D_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Conv2D)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // // launcher.add_region_requirement( - // // RegionRequirement(weights[1]->part, 0/*projection id*/, - // // READ_ONLY, EXCLUSIVE, weights[1]->region)); - // // launcher.add_field(3, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // weights[0]->region_grad)); - // launcher.add_field(3, FID_DATA); - // // launcher.add_region_requirement( - // // RegionRequirement(inputs[0]->part_grad, 0/*projection id*/, - // // WRITE_ONLY, EXCLUSIVE, inputs[0]->region_grad)); - // // launcher.add_field(4, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); + return profile(forward_kernel, + profiling, + "[Conv2d] forward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output.get_float_ptr(), + filter.get_float_ptr(), + bias.get_float_ptr()); } -/* - regions[0]: input - regions[1]: output - regions[2](I): filter - regions[3](I): bias - regions[4](O): filter_grad - regions[5](O): input_grad -*/ -PerDeviceOpState *Conv2D::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - // Conv2D const *conv = (Conv2D *)task->args; +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - FFHandler handle = *((FFHandler const *)task->local_args); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - // TensorAccessorR acc_input( - // regions[0], task->regions[0], FID_DATA, ctx, runtime); - // TensorAccessorW acc_output(regions[1], - // task->regions[1], - // FID_DATA, - // ctx, - // runtime, - // false - // /*readOutput*/); - // TensorAccessorR acc_kernel( - // regions[2], task->regions[2], FID_DATA, ctx, runtime); - // TensorAccessorR acc_bias( - // regions[3], task->regions[3], FID_DATA, ctx, runtime); - // TensorAccessorW acc_kernel_grad( - // regions[3], - // task->regions[3], - // FID_DATA, - // ctx, - // runtime, - // false /*readOutput*/); - // TensorAccessorW acc_input_grad( - // regions[4], task->regions[4], FID_DATA, ctx, runtime, - // false/*readOutput*/); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto filter = acc.get_tensor(FILTER); - auto bias = acc.get_tensor(BIAS); - auto filter_grad = acc.get_tensor(FILTER_GRAD); - auto input_grad = acc.get_tensor(INPUT_GRAD); - - Conv2DPerDeviceState *m = new Conv2DPerDeviceState(handle); - m->relu = attrs.activation == AC_MODE_RELU; - m->use_bias = attrs.use_bias; - m->profiling = profiling; - // m->trainableInputs[0] = conv->trainableInputs[0]; ?? - std::strcpy(m->op_name, attrs.name); - - int input_w = input.shape[0]; - int input_h = input.shape[1]; - int input_c = input.shape[2]; - int input_n = input.shape[3]; - int output_w = output.shape[0]; - int output_h = output.shape[1]; - int output_c = output.shape[2]; - int output_n = output.shape[3]; - - printf("init conv (input): n(%d) c(%d) h(%d) w(%d)\n", - input_n, - input_c, - input_h, - input_w); - printf("init conv (output): n(%d) c(%d) h(%d) w(%d)\n", - output_n, - output_c, - output_h, - output_w); - - // printf("convDim: padding(%d %d) stride(%d %d)\n", conv->padding_h, - // conv->padding_w, conv->stride_h, conv->stride_w); - int pad_h = - ((output_h - 1) * attrs.stride_h + attrs.kernel_h - input_h + 1) / 2; - int pad_w = - ((output_w - 1) * attrs.stride_w + attrs.kernel_w - input_w + 1) / 2; - if (pad_h != attrs.padding_h) { - printf("Warning: changing conv_padding_h to satisfy output_h size\n"); - } - if (pad_w != attrs.padding_w) { - printf("Warning: changing conv_padding_w to satisfy output_w size\n"); - } - - init_kernel(m, - input_w, - input_h, - input_c, - input_n, - output_w, - output_h, - output_c, - output_n, - attrs.kernel_h, - attrs.kernel_w, - attrs.groups, - attrs.stride_h, - attrs.stride_w, - pad_h, - pad_w, - input.get_float_ptr(), - output.get_float_ptr(), - filter.get_float_ptr(), - filter_grad.get_float_ptr()); - - return m; + forward_task_impl(acc); } -// TaskSpec Conv2D::get_tasks_spec() const { -// OpTasksSpec spec { -// CONV2D_INIT_TASK_ID, -// CONV2D_FWD_TASK_ID, -// CONV2D_BWD_TASK_ID -// }; -// auto &fwd = spec.get_fwd(); - -// fwd.add_input_slot(INPUT); -// fwd.add_param_slot(KERNEL); -// fwd.add_output_slot(OUTPUT); - -// auto input = spec.input_tensor(0); -// auto kernel = spec.param_tensor(0); -// auto bias = spec.param_tensor(1); -// auto output = spec.output_tensor(0); - -// fwd[INPUT] = input; -// fwd[KERNEL] = kernel; -// if (this->use_bias) { -// fwd[BIAS] = bias; -// } -// fwd[OUTPUT] = output; - -// return spec; -// } - -/* TaskSpec Conv2D::get_forward_task_spec() const { */ -/* TaskSpec spec = { CONV2D_FWD_TASK_ID, Pass::FWD }; */ - -/* auto input = spec.add_tensor(TensorRole::INPUT, 0); */ -/* auto kernel = spec.add_tensor(TensorRole::PARAM, 0); */ -/* auto bias = spec.add_tensor(TensorRole::BIAS, 1); */ -/* auto output = spec.add_tensor(TensorRole::OUTPUT, 0); */ - -/* spec.add_input(INPUT, input); */ -/* spec.add_input(KERNEL, kernel); */ - -/* if (this->use_bias) { */ -/* spec.add_input(BIAS, bias); */ -/* } */ - -/* spec.add_output(OUTPUT, output); */ - -/* return spec; */ -/* } */ - -/* TaskSpec Conv2D::get_backward_task_spec() const { */ -/* TaskSpec spec = { CONV2D_BWD_TASK_ID, Pass::BWD }; */ - -/* auto input = spec.add_tensor(TensorRole::INPUT, 0); */ -/* auto kernel = spec.add_tensor(TensorRole::PARAM, 0); */ -/* auto bias = spec.add_tensor(TensorRole::BIAS, 1); */ -/* auto output = spec.add_tensor(TensorRole::OUTPUT, 0); */ - -/* spec.add_input(INPUT, input); */ -/* spec.add_output(INPUT_GRAD, input.grad); */ -/* spec.add_input(KERNEL, kernel); */ -/* spec.add_output(KERNEL_GRAD, kernel.grad); */ - -/* if (this->use_bias) { */ -/* spec.add_input(BIAS, bias); */ -/* spec.add_output(BIAS_GRAD, bias.grad); */ -/* } */ - -/* spec.add_input(OUTPUT, output); */ -/* spec.add_input(OUTPUT_GRAD, output.grad); */ +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); -/* return spec; */ -/* } */ + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto filter = acc.get_tensor(FILTER); -void Conv2D::forward(FFModel const &ff) { - this->execute_task(ff, CONV2D_FWD_TASK_ID, get_fwd_task_signature()); -} + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + auto filter_grad = acc.get_tensor_grad(FILTER); + auto bias_grad = acc.get_tensor_grad(BIAS); -void Conv2D::backward(FFModel const &ff) { - this->execute_task(ff, CONV2D_bWD_TASK_ID, get_bwd_task_signature()); + return profile(backward_kernel, + profiling, + "[Conv2d] backward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + input_grad.get_float_ptr(), + output.get_float_ptr(), + output_grad.get_float_ptr(), + filter.get_float_ptr(), + filter_grad.get_float_ptr(), + bias_grad.get_float_ptr()); } -/* - regions[0](I): input - regions[1](O): output - regions[2](I): filter - regions[3](I): bias -*/ -void Conv2D::forward_task(Task const *task, +static void backward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - Conv2DPerDeviceState const *m = *((Conv2DPerDeviceState **)task->local_args); - TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); +} - auto input = acc.get_tensor(INPUT); - auto filter = acc.get_tensor(FILTER); - auto bias = acc.get_tensor(BIAS); - auto output = acc.get_tensor(OUTPUT); - - // TensorAccessorR acc_input( - // regions[0], task->regions[0], FID_DATA, ctx, runtime); - // TensorAccessorW acc_output(regions[1], - // task->regions[1], - // FID_DATA, - // ctx, - // runtime, - // false - // /*readOutput*/); - // TensorAccessorR acc_kernel( - // regions[2], task->regions[2], FID_DATA, ctx, runtime); - // float const *acc_bias_ptr = NULL; - // if (m->use_bias) { - // TensorAccessorR acc_bias( - // regions[3], task->regions[3], FID_DATA, ctx, runtime); - // acc_bias_ptr = acc_bias.ptr; - // } +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + Conv2DAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + InputParallelTensorDesc const &filter_shape, + InputParallelTensorDesc const &bias_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + + auto env = sim.new_environment(); - profile(forward_kernel, - m->profiling, - "[Conv2d] forward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output.get_float_ptr(), - filter.get_float_ptr(), - bias.get_float_ptr()); -} + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape.shape); -/* - region(I): input - region(I/O): input_grad (if trainableInputs[0]) - region(I): output - region(I/O): output_grad - region(I): filter - region(I/O): filter_grad - region(I/O): bias_grad (if use_bias) -*/ -void Conv2D::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - // Conv2D* conv = (Conv2D*) task->args; - Conv2DPerDeviceState const *m = *((Conv2DPerDeviceState **)task->local_args); - assert(regions.size() == (5 + static_cast(m->trainableInputs[0]) + - static_cast(m->use_bias))); - assert(task->regions.size() == - (5 + static_cast(m->trainableInputs[0]) + - static_cast(m->use_bias))); - size_t rid = 0; - TensorAccessorR acc_input( - regions[rid], task->regions[rid], FID_DATA, ctx, runtime); - rid++; - float *acc_input_grad_ptr = NULL; - if (m->trainableInputs[0]) { - TensorAccessorW acc_input_grad( - regions[rid], - task->regions[rid], - FID_DATA, - ctx, - runtime, - true /*readOutput*/); - acc_input_grad_ptr = acc_input_grad.ptr; - rid++; - } - TensorAccessorR acc_output( - regions[rid], task->regions[rid], FID_DATA, ctx, runtime); - rid++; - TensorAccessorW acc_output_grad( - regions[rid], - task->regions[rid], - FID_DATA, - ctx, - runtime, - true /*readOutput*/); - rid++; - TensorAccessorR acc_kernel( - regions[rid], task->regions[rid], FID_DATA, ctx, runtime); - rid++; - TensorAccessorW acc_kernel_grad( - regions[rid], - task->regions[rid], - FID_DATA, - ctx, - runtime, - true /*readOutput*/); - rid++; - float *acc_bias_grad_ptr = NULL; - if (m->use_bias) { - TensorAccessorW acc_bias_grad( - regions[rid], - task->regions[rid], - FID_DATA, - ctx, - runtime, - true /*readOutput*/); - acc_bias_grad_ptr = static_cast(acc_bias_grad.ptr); - rid++; - } - assert(rid == regions.size()); + SimTaskBinding init_binding; + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(HANDLE, ff_handle()); - backward_kernel_wrapper(m, - acc_input.ptr, - acc_input_grad_ptr, - acc_output.ptr, - acc_output_grad.ptr, - acc_kernel.ptr, - acc_kernel_grad.ptr, - acc_bias_grad_ptr); -} + auto init_accessor = + env.get_init_accessor(CONV2D_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); -bool Conv2D::estimate_sync_cost(Simulator *sim, - MachineView const &view, - CostMetrics &cost_metrics) const { - ParallelDim kernel_dims[MAX_TENSOR_DIM], bias_dims[MAX_TENSOR_DIM]; - int kernel_ndims, bias_ndims; + SimTaskBinding fwd_binding; + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind(FILTER, filter_shape); + fwd_binding.bind(BIAS, bias_shape); - this->get_params().solve_dims(this->inputs[0]->get_shape(), - nullptr, - nullptr, - kernel_dims, - &kernel_ndims, - bias_dims, - &bias_ndims); + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - cost_metrics.sync_time = - sim->default_estimate_sync_cost(kernel_dims, kernel_ndims, view); + auto fwd_accessor = env.get_fwd_accessor(CONV2D_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(CONV2D_BWD_TASK_ID, bwd_binding); - if (this->use_bias) { - cost_metrics.sync_time += - sim->default_estimate_sync_cost(bias_dims, bias_ndims, view); - } + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); - return true; + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -tl::optional Conv2D::as_dot() const { - RecordFormatter rr; - RecordFormatter r; +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); - r << this->inputs[0]->get_shape().as_dot(); - r << "in_channels" << this->in_channels; - r << "out_channels" << this->out_channels; - r << "kernel_h" << this->kernel_h; - r << "kernel_w" << this->kernel_w; - r << "padding_h" << this->padding_h; - r << "padding_w" << this->padding_w; - r << "stride_h" << this->stride_h; - r << "stride_w" << this->stride_w; - r << this->outputs[0]->get_shape().as_dot(); - rr << r; + init.add_arg_slot(ATTRS); + init.add_unchecked_arg_slot(HANDLE); - return rr; + register_task(CONV2D_INIT_TASK_ID, "Conv2D Init", init, init_task); } -bool Conv2D::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - int input_w = sub_input.dims[0].size; - int input_h = sub_input.dims[1].size; - int input_c = sub_input.dims[2].size; - int input_n = sub_input.dims[3].size; - int output_w = sub_output.dims[0].size; - int output_h = sub_output.dims[1].size; - int output_c = sub_output.dims[2].size; - int output_n = sub_output.dims[3].size; - int pad_h = ((output_h - 1) * stride_h + kernel_h - input_h + 1) / 2; - int pad_w = ((output_w - 1) * stride_w + kernel_w - input_w + 1) / 2; - - Conv2DPerDeviceState *m = sim->conv2d_meta; - m->relu = activation == AC_MODE_RELU; - // require input_c is divisible by groups +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); - // allocate tensors in simulator - sim->free_all(); - float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + fwd.add_weight_slot(FILTER); + fwd.add_weight_slot(BIAS); - float *weight_ptr = (float *)sim->allocate( - (size_t)output_c * input_c * kernel_h * kernel_w / groups, DT_FLOAT); - assert(weight_ptr != NULL); - float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_ptr != NULL); - cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset); + register_task(CONV2D_FWD_TASK_ID, "Conv2D Fwd", fwd, forward_task); +} - init_kernel(m, - input_w, - input_h, - input_c, - input_n, - output_w, - output_h, - output_c, - output_n, - kernel_h, - kernel_w, - groups, - stride_h, - stride_w, - pad_h, - pad_w, - input_ptr, - output_ptr, - weight_ptr, - weight_ptr, // note we reuse weight_ptr for kernel_grad_ptr here - // to avoid allocating another tensor - &cost_metrics.forward_time, - &cost_metrics.backward_time); +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(CONV2D_FWD_TASK_ID)); - log_measure.debug("[Measure Conv2D] name(%s) input(%d %d %d %d) weight(%d %d " - "%d %d) output(%d %d %d %d) stride(%d %d) padding(%d %d) " - "forward_time(%.4lf) backward_time(%.4lf)\n", - name, - input_n, - input_c, - input_h, - input_w, - output_c, - input_c / groups, - kernel_h, - kernel_w, - output_n, - output_c, - output_h, - output_w, - stride_h, - stride_w, - padding_h, - padding_w, - cost_metrics.forward_time, - cost_metrics.backward_time); - return true; + register_task(CONV2D_BWD_TASK_ID, "Conv2D Bwd", bwd, backward_task); } } // namespace FlexFlow diff --git a/lib/runtime/src/ops/conv_2d.h b/lib/runtime/src/ops/conv_2d.h index 382538b70a..ae76778670 100644 --- a/lib/runtime/src/ops/conv_2d.h +++ b/lib/runtime/src/ops/conv_2d.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_CONV_2D_H #include "op-attrs/ops/conv_2d.h" -#include "op_task_invocation.h" +#include "task_spec/op_task_invocation.h" #include "sim_environment.h" namespace FlexFlow { @@ -134,3 +134,702 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + + + +// Tensor FFModel::conv2d(Tensor const &input, +// int outChannels, +// int kernelH, +// int kernelW, +// int strideH, +// int strideW, +// int paddingH, +// int paddingW, +// ActiMode activation, +// int groups, +// bool use_bias, +// Layer const *shared_op, +// Initializer *kernel_initializer, +// Initializer *bias_initializer, +// char const *name) { +// assert(input->num_dims() == 4); /*NCHW*/ + +// Conv2DAttrs attrs = {outChannels, +// kernelH, +// kernelW, +// strideH, +// strideW, +// paddingH, +// paddingW, +// groups, +// activation, +// use_bias}; + +// TensorShape output_shape = get_output_shape(attrs, input->get_shape()); +// Tensor output = this->tensor_mgr.create(output_shape, CreateGrad::YES, conv); + +// std::vector weights; + +// TensorShape kernel_shape = get_kernel_shape(attrs, input->get_shape()); +// weights.push_back(this->tensor_mgr.create( +// kernel_shape, CreateGrad::YES, kernel_initializer, CHOSEN_SYNC_TYPE)); + +// if (use_bias) { +// TensorShape bias_shape = get_bias_shape(attrs, input->get_shape()); +// weights.push_back(this->tensor_mgr.create( +// bias_shape, CreateGrad::YES, bias_initializer, CHOSEN_SYNC_TYPE)); +// } + +// Layer *conv = +// this->layer_mgr.create(attrs, DT_FLOAT, name, {input}, weights, {output}); + +// //{ +// // int numdims = 4; +// // int dims[MAX_TENSOR_DIM]; +// // dims[3] = input->dims[3]; +// // dims[2] = outChannels; +// // dims[1] = 1 + (input->dims[1] + 2 * paddingH - kernelH) / strideH; +// // dims[0] = 1 + (input->dims[0] + 2 * paddingW - kernelW) / strideW; +// // conv->outputs[0] = create_tensor_legion_ordering( +// // numdims, dims, DT_FLOAT, conv, 0, true /*create_grad*/); +// //} +// //{ +// // int dims[4] = {kernelW, kernelH, input->dims[2], outChannels}; +// // conv->weights[0] = create_weight_legion_ordering(4, +// // dims, +// // DT_FLOAT, +// // conv, +// // true /*create_grad*/, +// // kernel_initializer, +// // CHOSEN_SYNC_TYPE); +// //} +// // if (use_bias) { +// // int dims[1] = {outChannels}; +// // conv->weights[1] = create_weight_legion_ordering(1, +// // dims, +// // DT_FLOAT, +// // conv, +// // true /*create_grad*/, +// // bias_initializer, +// // CHOSEN_SYNC_TYPE); +// //} +// conv->add_initializer("kernel", kernel_initializer); +// conv->add_initializer("bias", bias_initializer); +// /* layers.push_back(conv); */ +// return conv->outputs[0]; +// } + +// Op *Conv2D::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// return new Conv2D(model, +// get(layer->attrs), +// inputs, +// layer->name, +// false /*allocate_weights*/ +// ); +// } + + +/* void Conv2DParams::mark_replica_dims( */ +/* ParallelTensorShape const &input, */ +/* ParallelDim output_dims[MAX_TENSOR_DIM], */ +/* ParallelDim kernel_dims[MAX_TENSOR_DIM], */ +/* ParallelDim bias_dims[MAX_TENSOR_DIM]) const { */ +/* if (output_dims != nullptr) { */ +/* output_dims[Conv2DOutput::REPLICA].is_replica_dim = true; */ +/* } */ +/* if (kernel_dims != nullptr) { */ +/* kernel_dims[Conv2DOutput::REPLICA].is_replica_dim = true; */ +/* } */ +/* if (bias_dims != nullptr) { */ +/* bias_dims[Conv2DBias::REPLICA_1].is_replica_dim = true; */ +/* bias_dims[Conv2DBias::REPLICA_2].is_replica_dim = true; */ +/* bias_dims[Conv2DBias::REPLICA_3].is_replica_dim = true; */ +/* bias_dims[Conv2DBias::REPLICA_4].is_replica_dim = true; */ +/* } */ +/* } */ + +/* int Conv2DParams::output_size(ParallelTensorShape const &input, */ +/* ParallelDim output_dims[MAX_TENSOR_DIM]) const + * { */ +/* int input_w = input.dims[Conv2DInput::WIDTH].size; */ +/* int input_h = input.dims[Conv2DInput::HEIGHT].size; */ + +/* output_dims[Conv2DOutput::SAMPLE].size = + * input.dims[Conv2DInput::SAMPLE].size; */ +/* output_dims[Conv2DOutput::CHANNEL].size = out_channels; */ +/* output_dims[Conv2DOutput::HEIGHT].size = */ +/* 1 + (input_h + 2 * padding_h - kernel_h) / stride_h; */ +/* output_dims[Conv2DOutput::WIDTH].size = */ +/* 1 + (input_w + 2 * padding_w - kernel_w) / stride_w; */ + +/* return input.num_dims; */ +/* }; */ + +/* int Conv2DParams::kernel_size(ParallelTensorShape const &input, */ +/* ParallelDim kernel_dims[MAX_TENSOR_DIM]) const + * { */ +/* kernel_dims[Conv2DKernel::CHANNEL_OUT].size = this->out_channels; */ +/* kernel_dims[Conv2DKernel::CHANNEL_IN].size = */ +/* input.dims[Conv2DInput::CHANNEL].size / this->groups; */ +/* kernel_dims[Conv2DKernel::HEIGHT].size = */ +/* this->kernel_h * input.dims[Conv2DInput::HEIGHT].degree; */ +/* kernel_dims[Conv2DKernel::WIDTH].size = */ +/* this->kernel_w * input.dims[Conv2DInput::WIDTH].degree; */ + +/* return Conv2DKernel::NUMDIM; */ +/* } */ + +/* int Conv2DParams::bias_size(ParallelTensorShape const &input, */ +/* ParallelDim bias_dims[MAX_TENSOR_DIM]) const { */ +/* bias_dims[Conv2DBias::CHANNEL].size = this->out_channels; */ + +/* return Conv2DBias::NUMDIM; */ +/* }; */ + +/* void Conv2DParams::solve_dims(ParallelTensorShape const &input, */ +/* ParallelDim output_dims[MAX_TENSOR_DIM], */ +/* int *output_ndims, */ +/* ParallelDim kernel_dims[MAX_TENSOR_DIM], */ +/* int *kernel_ndims, */ +/* ParallelDim bias_dims[MAX_TENSOR_DIM], */ +/* int *bias_ndims) const { */ +/* assert((output_dims == nullptr) == (output_ndims == nullptr)); */ +/* assert((kernel_dims == nullptr) == (kernel_ndims == nullptr)); */ +/* assert((bias_dims == nullptr) == (bias_ndims == nullptr)); */ + +/* std::vector mapping; */ +/* Conv2D::construct_mappings(mapping, this->use_bias); */ + +/* this->mark_replica_dims(input, output_dims, kernel_dims, bias_dims); */ + +/* std::vector output_dim_sets; */ +/* if (output_dims != nullptr) { */ +/* output_dim_sets.push_back(output_dims); */ +/* } */ + +/* std::vector weight_dim_sets; */ +/* if (kernel_dims != nullptr) { */ +/* weight_dim_sets.push_back(kernel_dims); */ +/* } */ +/* if (bias_dims != nullptr && this->use_bias) { */ +/* weight_dim_sets.push_back(bias_dims); */ +/* } */ + +/* solve_parallel_dim_mappings( */ +/* mapping, {input.dims}, weight_dim_sets, output_dim_sets); */ + +/* if (output_dims != nullptr) { */ +/* *output_ndims = this->output_size(input, output_dims); */ +/* } */ +/* if (kernel_dims != nullptr) { */ +/* *kernel_ndims = this->kernel_size(input, kernel_dims); */ +/* } */ +/* if (bias_dims != nullptr && this->use_bias) { */ +/* *bias_ndims = this->bias_size(input, bias_dims); */ +/* } */ +/* } */ + +/*static*/ +/* void Conv2D::construct_mappings(std::vector &out, + */ +/* bool use_bias) { */ +/* Conv2D::construct_output_mappings(out); */ +/* Conv2D::construct_weight_mappings(out, use_bias); */ +/* } */ + +/*static*/ +/* void Conv2D::construct_output_mappings( */ +/* std::vector &out) { */ +/* Op::construct_output_parallel_dims( */ +/* out, */ +/* {{Conv2DInput::CHANNEL, */ +/* MappingOperation::REPLICATE, */ +/* Conv2DOutput::REPLICA}, */ +/* {Conv2DInput::SAMPLE, MappingOperation::PARTITION, + * Conv2DOutput::SAMPLE}, */ +/* {Conv2DInput::REPLICA, */ +/* MappingOperation::PARTITION, */ +/* Conv2DOutput::CHANNEL}, */ +/* {Conv2DInput::HEIGHT, MappingOperation::PARTITION, + * Conv2DOutput::HEIGHT}, */ +/* {Conv2DInput::WIDTH, MappingOperation::PARTITION, + * Conv2DOutput::WIDTH}}); */ +/* } */ + +/*static*/ +/* void Conv2D::construct_weight_mappings( */ +/* std::vector &out, bool use_bias) { */ +/* Op::construct_weight_parallel_dims( */ +/* out, */ +/* { */ +/* {Conv2DInput::REPLICA, */ +/* MappingOperation::PARTITION, */ +/* Conv2DKernel::CHANNEL_OUT}, */ +/* {Conv2DInput::SAMPLE, */ +/* MappingOperation::REPLICATE, */ +/* Conv2DKernel::REPLICA}, */ +/* {Conv2DInput::CHANNEL, */ +/* MappingOperation::PARTITION, */ +/* Conv2DKernel::CHANNEL_IN}, */ +/* {Conv2DInput::HEIGHT, */ +/* MappingOperation::REPLICATE, */ +/* Conv2DKernel::HEIGHT}, // Kernel::{HEIGHT, WEIGHT} would both work + */ +/* // here */ +/* {Conv2DInput::WIDTH, */ +/* MappingOperation::REPLICATE, */ +/* Conv2DKernel::WIDTH}, // same as above */ +/* }, */ +/* Conv2DInput::INDEX, */ +/* Conv2DKernel::INDEX); */ + +/* if (use_bias) { */ +/* Op::construct_weight_parallel_dims( */ +/* out, */ +/* {{Conv2DInput::REPLICA, Conv2DBias::REPLICA_1}, */ +/* {Conv2DInput::SAMPLE, Conv2DBias::REPLICA_2}, */ +/* {Conv2DInput::CHANNEL, Conv2DBias::CHANNEL}, */ +/* {Conv2DInput::HEIGHT, Conv2DBias::REPLICA_3}, */ +/* {Conv2DInput::WIDTH, Conv2DBias::REPLICA_4}}, */ +/* Conv2DInput::INDEX, */ +/* Conv2DBias::INDEX); */ +/* } */ +/* } */ + + + +// Conv2D::Conv2D(FFModel &model, +// Conv2D const &other, +// const ParallelTensor input, +// bool allocate_weights) +// : Conv2D(model, +// other.layer_guid, +// input, +// other.out_channels, +// other.kernel_h, +// other.kernel_w, +// other.stride_h, +// other.stride_w, +// other.padding_h, +// other.padding_w, +// other.activation, +// other.groups, +// other.use_bias, +// allocate_weights, +// other.name) {} + +// Conv2D::Conv2D(FFModel &model, +// Conv2DAttrs const &attrs, +// std::vector const &inputs, +// char const *name, +// bool allocate_weights) +// : Conv2D(model, +// params.layer_guid, +// input, +// params.out_channels, +// params.kernel_h, +// params.kernel_w, +// params.stride_h, +// params.stride_w, +// params.padding_h, +// params.padding_w, +// params.activation, +// params.groups, +// params.use_bias, +// allocate_weights, +// name) {} + +/* bool Conv2DParams::is_valid(ParallelTensorShape const &input) const { */ +/* ParallelTensorShape output_shape, kernel_shape, bias_shape; */ +/* this->solve_dims(input, */ +/* output_shape.dims, */ +/* &output_shape.num_dims, */ +/* kernel_shape.dims, */ +/* &kernel_shape.num_dims, */ +/* bias_shape.dims, */ +/* &bias_shape.num_dims); */ +/* bool is_valid = true; */ +/* is_valid &= input.is_valid(); */ +/* is_valid &= output_shape.is_valid(); */ +/* is_valid &= kernel_shape.is_valid(); */ +/* if (use_bias) { */ +/* is_valid &= bias_shape.is_valid(); */ +/* } */ + +/* // TODO FIXME: Currently disable parallelizing the height and width + * dimension */ +/* if (input.dims[0].degree > 1 || input.dims[1].degree > 1) { */ +/* return false; */ +/* } */ + +/* return is_valid; */ +/* } */ + + + +// Conv2D::Conv2D(FFModel &model, +// LayerID const &_layer_guid, +// const ParallelTensor input, +// int outChannels, +// int kernelH, +// int kernelW, +// int strideH, +// int strideW, +// int paddingH, +// int paddingW, +// ActiMode activation, +// int groups, +// bool use_bias, +// bool allocate_weights, +// char const *name) +// : Op(model, +// OP_CONV2D, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// use_bias ? 2 : 1 /*weights*/, +// allocate_weights, +// 1 /*outputs*/, +// input), +// in_channels(input->dims[Conv2DInput::CHANNEL].size / +// input->dims[Conv2DInput::CHANNEL].degree), +// out_channels(outChannels), kernel_h(kernelH), kernel_w(kernelW), +// stride_h(strideH), stride_w(strideW), padding_h(paddingH), +// padding_w(paddingW), activation(activation), groups(groups), +// use_bias(use_bias) { +// // overwrite layer_guid +// layer_guid = _layer_guid; +// assert(input->num_dims == Conv2DInput::NUMDIM); +// assert(this->stride_h > 0); +// assert(this->stride_w > 0); + +// ParallelDim output_dims[MAX_TENSOR_DIM], kernel_dims[MAX_TENSOR_DIM], +// bias_dims[MAX_TENSOR_DIM]; +// int output_ndims, kernel_ndims, bias_ndims; + +// this->construct_mappings(*this->parallel_dims_mapping, this->use_bias); +// this->get_params().solve_dims(this->inputs[0]->get_shape(), +// output_dims, +// &output_ndims, +// kernel_dims, +// &kernel_ndims, +// bias_dims, +// &bias_ndims); + +// if (allocate_weights) { +// Initializer *kernel_initializer = new GlorotUniform(std::rand() /*seed*/); + +// weights[Conv2DKernel::INDEX] = +// model.create_parallel_weight_legion_ordering(kernel_ndims, +// kernel_dims, +// DT_FLOAT, +// NULL /*owner_op*/, +// true /*create_grad*/, +// kernel_initializer, +// CHOSEN_SYNC_TYPE); + +// if (use_bias) { +// Initializer *bias_initializer = new ZeroInitializer(); + +// weights[Conv2DBias::INDEX] = +// model.create_parallel_weight_legion_ordering(bias_ndims, +// bias_dims, +// DT_FLOAT, +// NULL /*owner_op*/, +// true /*create_grad*/, +// bias_initializer, +// CHOSEN_SYNC_TYPE); +// } +// } + +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// output_ndims, output_dims, DT_FLOAT, this); + +// assert(check_output_input_weight_parallel_dims(allocate_weights)); +// } + + +// tl::optional Conv2D::as_dot() const { +// RecordFormatter rr; +// RecordFormatter r; + +// r << this->inputs[0]->get_shape().as_dot(); +// r << "in_channels" << this->in_channels; +// r << "out_channels" << this->out_channels; +// r << "kernel_h" << this->kernel_h; +// r << "kernel_w" << this->kernel_w; +// r << "padding_h" << this->padding_h; +// r << "padding_w" << this->padding_w; +// r << "stride_h" << this->stride_h; +// r << "stride_w" << this->stride_w; +// r << this->outputs[0]->get_shape().as_dot(); +// rr << r; + +// return rr; +// } + + +// bool Conv2D::estimate_sync_cost(Simulator *sim, +// MachineView const &view, +// CostMetrics &cost_metrics) const { +// ParallelDim kernel_dims[MAX_TENSOR_DIM], bias_dims[MAX_TENSOR_DIM]; +// int kernel_ndims, bias_ndims; + +// this->get_params().solve_dims(this->inputs[0]->get_shape(), +// nullptr, +// nullptr, +// kernel_dims, +// &kernel_ndims, +// bias_dims, +// &bias_ndims); + +// cost_metrics.sync_time = +// sim->default_estimate_sync_cost(kernel_dims, kernel_ndims, view); + +// if (this->use_bias) { +// cost_metrics.sync_time += +// sim->default_estimate_sync_cost(bias_dims, bias_ndims, view); +// } + +// return true; +// } + +/* + region(I): input + region(I/O): input_grad (if trainableInputs[0]) + region(I): output + region(I/O): output_grad + region(I): filter + region(I/O): filter_grad + region(I/O): bias_grad (if use_bias) +*/ + +/* + regions[0](I): input + regions[1](O): output + regions[2](I): filter + regions[3](I): bias +*/ + + +// void Conv2D::forward(FFModel const &ff) { +// this->execute_task(ff, CONV2D_FWD_TASK_ID, get_fwd_task_signature()); +// } + +// void Conv2D::backward(FFModel const &ff) { +// this->execute_task(ff, CONV2D_bWD_TASK_ID, get_bwd_task_signature()); +// } + + +// TaskSpec Conv2D::get_tasks_spec() const { +// OpTasksSpec spec { +// CONV2D_INIT_TASK_ID, +// CONV2D_FWD_TASK_ID, +// CONV2D_BWD_TASK_ID +// }; +// auto &fwd = spec.get_fwd(); + +// fwd.add_input_slot(INPUT); +// fwd.add_param_slot(KERNEL); +// fwd.add_output_slot(OUTPUT); + +// auto input = spec.input_tensor(0); +// auto kernel = spec.param_tensor(0); +// auto bias = spec.param_tensor(1); +// auto output = spec.output_tensor(0); + +// fwd[INPUT] = input; +// fwd[KERNEL] = kernel; +// if (this->use_bias) { +// fwd[BIAS] = bias; +// } +// fwd[OUTPUT] = output; + +// return spec; +// } + +/* TaskSpec Conv2D::get_forward_task_spec() const { */ +/* TaskSpec spec = { CONV2D_FWD_TASK_ID, Pass::FWD }; */ + +/* auto input = spec.add_tensor(TensorRole::INPUT, 0); */ +/* auto kernel = spec.add_tensor(TensorRole::PARAM, 0); */ +/* auto bias = spec.add_tensor(TensorRole::BIAS, 1); */ +/* auto output = spec.add_tensor(TensorRole::OUTPUT, 0); */ + +/* spec.add_input(INPUT, input); */ +/* spec.add_input(KERNEL, kernel); */ + +/* if (this->use_bias) { */ +/* spec.add_input(BIAS, bias); */ +/* } */ + +/* spec.add_output(OUTPUT, output); */ + +/* return spec; */ +/* } */ + +/* TaskSpec Conv2D::get_backward_task_spec() const { */ +/* TaskSpec spec = { CONV2D_BWD_TASK_ID, Pass::BWD }; */ + +/* auto input = spec.add_tensor(TensorRole::INPUT, 0); */ +/* auto kernel = spec.add_tensor(TensorRole::PARAM, 0); */ +/* auto bias = spec.add_tensor(TensorRole::BIAS, 1); */ +/* auto output = spec.add_tensor(TensorRole::OUTPUT, 0); */ + +/* spec.add_input(INPUT, input); */ +/* spec.add_output(INPUT_GRAD, input.grad); */ +/* spec.add_input(KERNEL, kernel); */ +/* spec.add_output(KERNEL_GRAD, kernel.grad); */ + +/* if (this->use_bias) { */ +/* spec.add_input(BIAS, bias); */ +/* spec.add_output(BIAS_GRAD, bias.grad); */ +/* } */ + +/* spec.add_input(OUTPUT, output); */ +/* spec.add_input(OUTPUT_GRAD, output.grad); */ + +/* return spec; */ +/* } */ + + +/* + regions[0]: input + regions[1]: output + regions[2](I): filter + regions[3](I): bias + regions[4](O): filter_grad + regions[5](O): input_grad +*/ + + +// void Conv2D::init(FFModel const &ff) { +// this->execute_task(ff, CONV2D_INIT_TASK_ID, get_init_task_signature()); + // assert(check_output_input_weight_same_parallel_is()); + // parallel_is = outputs[0]->parallel_is; + // ArgumentMap argmap; + // Context ctx = ff.config.lg_ctx; + // Runtime *runtime = ff.config.lg_hlr; + // set_argumentmap_for_init(ff, argmap); + // IndexLauncher launcher(CONV2D_INIT_TASK_ID, + // parallel_is, + // TaskArgument(this, sizeof(Conv2D)), + // argmap, + // Predicate::TRUE_PRED, + // false /*must*/, + // 0 /*mapper_id*/, + // outputs[0]->machine_view.hash()); + // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, + // 0 /*projection id*/, + // READ_ONLY, + // EXCLUSIVE, + // inputs[0]->region)); + // launcher.add_field(0, FID_DATA); + // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, + // 0 /*projection id*/, + // WRITE_ONLY, + // EXCLUSIVE, + // outputs[0]->region)); + // launcher.add_field(1, FID_DATA); + // launcher.add_region_requirement(RegionRequirement(weights[0]->part, + // 0 /*projection id*/, + // READ_ONLY, + // EXCLUSIVE, + // weights[0]->region)); + // launcher.add_field(2, FID_DATA); + // // launcher.add_region_requirement( + // // RegionRequirement(weights[1]->part, 0/*projection id*/, + // // READ_ONLY, EXCLUSIVE, weights[1]->region)); + // // launcher.add_field(3, FID_DATA); + // launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, + // 0 /*projection id*/, + // WRITE_ONLY, + // EXCLUSIVE, + // weights[0]->region_grad)); + // launcher.add_field(3, FID_DATA); + // // launcher.add_region_requirement( + // // RegionRequirement(inputs[0]->part_grad, 0/*projection id*/, + // // WRITE_ONLY, EXCLUSIVE, inputs[0]->region_grad)); + // // launcher.add_field(4, FID_DATA); + // FutureMap fm = runtime->execute_index_space(ctx, launcher); + // fm.wait_all_results(); + // set_opmeta_from_futuremap(ff, fm); +// } + + // printf("init conv (input): n(%d) c(%d) h(%d) w(%d)\n", + // input_n, + // input_c, + // input_h, + // input_w); + // printf("init conv (output): n(%d) c(%d) h(%d) w(%d)\n", + // output_n, + // output_c, + // output_h, + // output_w); + + // printf("convDim: padding(%d %d) stride(%d %d)\n", conv->padding_h, + // conv->padding_w, conv->stride_h, conv->stride_w); + // int pad_h = + // ((output_h - 1) * attrs.stride_h + attrs.kernel_h - input_h + 1) / 2; + // int pad_w = + // ((output_w - 1) * attrs.stride_w + attrs.kernel_w - input_w + 1) / 2; + // if (pad_h != attrs.padding_h) { + // printf("Warning: changing conv_padding_h to satisfy output_h size\n"); + // } + // if (pad_w != attrs.padding_w) { + // printf("Warning: changing conv_padding_w to satisfy output_w size\n"); + // } + +// size_t rid = 0; +// TensorAccessorR acc_input( +// regions[rid], task->regions[rid], FID_DATA, ctx, runtime); +// rid++; +// float *acc_input_grad_ptr = NULL; +// if (m->trainableInputs[0]) { +// TensorAccessorW acc_input_grad( +// regions[rid], +// task->regions[rid], +// FID_DATA, +// ctx, +// runtime, +// true /*readOutput*/); +// acc_input_grad_ptr = acc_input_grad.ptr; +// rid++; +// } +// TensorAccessorR acc_output( +// regions[rid], task->regions[rid], FID_DATA, ctx, runtime); +// rid++; +// TensorAccessorW acc_output_grad( +// regions[rid], +// task->regions[rid], +// FID_DATA, +// ctx, +// runtime, +// true /*readOutput*/); +// rid++; +// TensorAccessorR acc_kernel( +// regions[rid], task->regions[rid], FID_DATA, ctx, runtime); +// rid++; +// TensorAccessorW acc_kernel_grad( +// regions[rid], +// task->regions[rid], +// FID_DATA, +// ctx, +// runtime, +// true /*readOutput*/); +// rid++; +// float *acc_bias_grad_ptr = NULL; +// if (m->use_bias) { +// TensorAccessorW acc_bias_grad( +// regions[rid], +// task->regions[rid], +// FID_DATA, +// ctx, +// runtime, +// true /*readOutput*/); +// acc_bias_grad_ptr = static_cast(acc_bias_grad.ptr); +// rid++; +// } \ No newline at end of file From bf750675ebabc1abbf3bad7620455f80a74d4b1a Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 1 Sep 2023 11:16:07 -0700 Subject: [PATCH 24/26] format --- lib/kernels/include/kernels/combine_kernels.h | 3 +-- lib/runtime/src/ops/combine.cc | 15 +++++---------- lib/runtime/src/ops/combine.h | 18 ++++++++---------- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/lib/kernels/include/kernels/combine_kernels.h b/lib/kernels/include/kernels/combine_kernels.h index 44ab67d9a7..24b9adb803 100644 --- a/lib/kernels/include/kernels/combine_kernels.h +++ b/lib/kernels/include/kernels/combine_kernels.h @@ -10,8 +10,7 @@ struct CombinePerDeviceState { req data_type; }; -FF_VISITABLE_STRUCT_NO_EQ(CombinePerDeviceState, - data_type); +FF_VISITABLE_STRUCT_NO_EQ(CombinePerDeviceState, data_type); namespace Kernels { namespace Combine { diff --git a/lib/runtime/src/ops/combine.cc b/lib/runtime/src/ops/combine.cc index c2d8d9a017..a22b317d10 100644 --- a/lib/runtime/src/ops/combine.cc +++ b/lib/runtime/src/ops/combine.cc @@ -26,12 +26,7 @@ using Legion::Task; using namespace FlexFlow::Kernels::Combine; -enum Slots { - INPUT, - OUTPUT, - PROFILING, - PER_DEVICE_STATE -}; +enum Slots { INPUT, OUTPUT, PROFILING, PER_DEVICE_STATE }; OpTaskInvocation init(CombineAttrs const &attrs) { OpTaskBinding binding; @@ -44,7 +39,8 @@ OpTaskInvocation init(CombineAttrs const &attrs) { OpTaskInvocation forward(CombineAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); binding.bind_arg(PROFILING, profiling_settings()); binding.bind(INPUT, input_tensor(0)); @@ -61,7 +57,7 @@ OpTaskInvocation backward(CombineAttrs const &attrs) { static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { - + auto input = acc.get_tensor(INPUT); DeviceSpecific per_device_state = @@ -171,5 +167,4 @@ void register_task() { }; // namespace FlexFlow -namespace std { -}; // namespace std +namespace std {}; // namespace std diff --git a/lib/runtime/src/ops/combine.h b/lib/runtime/src/ops/combine.h index 4b96ce6495..512c3be363 100644 --- a/lib/runtime/src/ops/combine.h +++ b/lib/runtime/src/ops/combine.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_COMBINE_H #include "op-attrs/ops/combine.h" -#include "task_spec/op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -112,7 +112,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // output_grad_ptr, input_grad_ptr, output_grad_domain.get_volume()); // } - // void Combine::backward_task(Task const *task, // std::vector const ®ions, // Context ctx, @@ -171,7 +170,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // return rf; // } - // void Combine::init(FFModel const &ff) { // parallel_is = outputs[0]->parallel_is; // ArgumentMap argmap; @@ -188,7 +186,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // 0 /*mapper_id*/, // outputs[0]->machine_view.hash()); // launcher.add_region_requirement(RegionRequirement( -// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); +// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, +// inputs[0]->region)); // launcher.add_field(0, FID_DATA); // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, // 0 /*projection id*/, @@ -232,7 +231,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // 0 /*mapper_id*/, // outputs[0]->machine_view.hash()); // launcher.add_region_requirement(RegionRequirement( -// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); +// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, +// inputs[0]->region)); // launcher.add_field(0, FID_DATA); // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, // 0 /*projection id*/, @@ -274,8 +274,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // runtime->execute_index_space(ctx, launcher); // } - - // CombineParams Combine::get_params() const { // CombineParams params; // params.combine_legion_dim = this->combine_dim; @@ -338,9 +336,9 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // template // void Combine::forward_task_with_type(Task const *task, -// std::vector const ®ions, -// Context ctx, -// Runtime *runtime) { +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { // Domain input_domain = runtime->get_index_space_domain( // ctx, task->regions[0].region.get_index_space()); // Domain output_domain = runtime->get_index_space_domain( From d1f4fb91729c7bf9cfe612c46cf33707c0ae656a Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 1 Sep 2023 11:17:05 -0700 Subject: [PATCH 25/26] format --- lib/kernels/include/kernels/concat_kernels.h | 11 ++-- lib/runtime/src/ops/concat.cc | 58 ++++++++++--------- lib/runtime/src/ops/concat.h | 47 ++++++++------- lib/runtime/src/serialization.h | 7 ++- .../src/task_spec/op_task_invocation.h | 4 +- lib/runtime/src/task_spec/op_tensor_spec.h | 1 - .../src/task_spec/variadic_tensor_ref.h | 6 +- 7 files changed, 67 insertions(+), 67 deletions(-) diff --git a/lib/kernels/include/kernels/concat_kernels.h b/lib/kernels/include/kernels/concat_kernels.h index 1bbefdf843..165f63f332 100644 --- a/lib/kernels/include/kernels/concat_kernels.h +++ b/lib/kernels/include/kernels/concat_kernels.h @@ -23,11 +23,12 @@ void forward_kernel(ffStream_t stream, std::vector const &inputs, int num_inputs); -void backward_kernel(ffStream_t stream, - ConcatPerDeviceState const *m, - GenericTensorAccessorR const &output_grad, - std::vector const &input_grads, - int num_inputs); +void backward_kernel( + ffStream_t stream, + ConcatPerDeviceState const *m, + GenericTensorAccessorR const &output_grad, + std::vector const &input_grads, + int num_inputs); } // namespace Concat } // namespace Kernels diff --git a/lib/runtime/src/ops/concat.cc b/lib/runtime/src/ops/concat.cc index 20c54c993e..d8f610c3ea 100644 --- a/lib/runtime/src/ops/concat.cc +++ b/lib/runtime/src/ops/concat.cc @@ -16,9 +16,9 @@ #include "concat.h" #include "kernels/concat_kernels.h" #include "legion/legion_utilities.h" +#include "op-attrs/get_output_shapes.h" #include "task_spec/variadic_tensor_ref.h" #include "utils/hash-utils.h" -#include "op-attrs/get_output_shapes.h" namespace FlexFlow { @@ -50,7 +50,8 @@ OpTaskInvocation init(ConcatAttrs const &attrs) { OpTaskInvocation forward(ConcatAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); binding.bind(INPUTS, get_input_tensors()); binding.bind(OUTPUT, output_tensor(0)); binding.bind(NUM_INPUTS, get_number_inputs()); @@ -70,7 +71,7 @@ static DeviceSpecific auto const &attrs = acc.get_argument(ATTRS); PerDeviceFFHandle handle = acc.get_argument(HANDLE); - DeviceSpecific per_device_state = + DeviceSpecific per_device_state = acc.create_device_specific(init_kernel(attrs.axis)); return per_device_state; } @@ -85,7 +86,8 @@ static DeviceSpecific } static optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); int number_of_inputs = acc.get_argument(NUM_INPUTS); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -93,12 +95,12 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { auto inputs = acc.get_variadic_tensor(INPUTS); return profile(forward_kernel, - profiling, - "[Concat] forward_time = %.2lfms\n", - &per_device_state, - output, - inputs, - number_of_inputs); + profiling, + "[Concat] forward_time = %.2lfms\n", + &per_device_state, + output, + inputs, + number_of_inputs); } static void forward_task(Task const *task, @@ -110,7 +112,8 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); int number_of_inputs = acc.get_argument(NUM_INPUTS); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -120,12 +123,12 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(number_of_inputs <= MAX_NUM_INPUTS); return profile(backward_kernel, - profiling, - "[Concat] backward_time = %.2lfms\n", - &per_device_state, - output_grad, - input_grads, - number_of_inputs); + profiling, + "[Concat] backward_time = %.2lfms\n", + &per_device_state, + output_grad, + input_grads, + number_of_inputs); } static void backward_task(Task const *task, @@ -136,24 +139,25 @@ static void backward_task(Task const *task, backward_task_impl(acc); } -CostMetrics measure_operator_cost(SimEnvFactory const &sim, - ConcatAttrs const &attrs, - InputVariadicParallelTensorDesc const &inputs_shape, - ProfilingSettings const &settings, - MachineView const &mv) { +CostMetrics + measure_operator_cost(SimEnvFactory const &sim, + ConcatAttrs const &attrs, + InputVariadicParallelTensorDesc const &inputs_shape, + ProfilingSettings const &settings, + MachineView const &mv) { int numInputs = (inputs_shape.shapes).size(); assert(numInputs <= MAX_NUM_INPUTS); auto env = sim.new_environment(); - ParallelTensorShape output_shape = get_output_shape(attrs, inputs_shape.shapes); + ParallelTensorShape output_shape = + get_output_shape(attrs, inputs_shape.shapes); SimTaskBinding init_binding; init_binding.bind_arg(PROFILING, settings); init_binding.bind_arg(ATTRS, attrs); - auto init_accessor = - env.get_init_accessor(CONCAT_INIT_TASK_ID, init_binding); + auto init_accessor = env.get_init_accessor(CONCAT_INIT_TASK_ID, init_binding); DeviceSpecific per_device_state = init_task_impl(init_accessor); @@ -174,7 +178,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, float sync_time = default_estimate_sync_time(env); return make_metrics(forward_time, backward_time, sync_time, env); - } template <> @@ -183,7 +186,7 @@ void register_task() { init.add_arg_slot(ATTRS); init.add_arg_slot(PROFILING); - + register_task(CONCAT_INIT_TASK_ID, "Concat Init", init, init_task); } @@ -208,5 +211,4 @@ void register_task() { register_task(CONCAT_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); } - }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/concat.h b/lib/runtime/src/ops/concat.h index 5f792566b4..0493006345 100644 --- a/lib/runtime/src/ops/concat.h +++ b/lib/runtime/src/ops/concat.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_CONCAT_H #include "op-attrs/ops/concat.h" -#include "task_spec/op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -77,7 +77,6 @@ CostMetrics #endif - // bool operator==(ConcatParams const &lhs, ConcatParams const &rhs) { // return lhs.axis == rhs.axis; // } @@ -89,7 +88,8 @@ CostMetrics // } // Tensor -// FFModel::concat(int n, Tensor const *tensors, int axis, char const *name) { +// FFModel::concat(int n, Tensor const *tensors, int axis, char const *name) +// { // Layer *concat = new Layer(this, // OP_CONCAT, // DT_FLOAT, @@ -178,7 +178,6 @@ CostMetrics // char const *name) // : Concat(model, inputs.size(), inputs.data(), params.axis, name) {} - // void Concat::init(FFModel const &ff) { // this->execute_task(ff, CONCAT_INIT_TASK_ID, get_init_task_signature()); // // assert(check_output_input_weight_same_parallel_is()); @@ -210,11 +209,11 @@ CostMetrics // // launcher.add_field(i + 1, FID_DATA); // // } // // for (int i = 0; i < numInputs; i++) { -// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, // // 0 /*projection id*/, // // WRITE_ONLY, // // EXCLUSIVE, -// // inputs[i]->region_grad)); +// // inputs[i]->region_grad)); // // launcher.add_field(i + numInputs + 1, FID_DATA); // // } // // FutureMap fm = runtime->execute_index_space(ctx, launcher); @@ -222,7 +221,6 @@ CostMetrics // // set_opmeta_from_futuremap(ff, fm); // } - // void Concat::forward(FFModel const &ff) { // this->execute_task(ff, CONCAT_FWD_TASK_ID, get_fwd_task_signature()); // // ArgumentMap argmap; @@ -259,7 +257,6 @@ CostMetrics regions[1..numInputs](I): inputs */ - // void Concat::backward(FFModel const &ff) { // this->execute_task(ff, CONCAT_BWD_TASK_ID, get_bwd_task_signature()); // // ArgumentMap argmap; @@ -278,14 +275,14 @@ CostMetrics // // 0 /*projection id*/, // // READ_ONLY, // // EXCLUSIVE, -// // outputs[0]->region_grad)); +// // outputs[0]->region_grad)); // // launcher.add_field(0, FID_DATA); // // for (int i = 0; i < numInputs; i++) { -// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, // // 0 /*projection id*/, // // READ_WRITE, // // EXCLUSIVE, -// // inputs[i]->region_grad)); +// // inputs[i]->region_grad)); // // // LogicalRegion lr = inputs[i]->region_grad; // // // printf("concat[%d]: region(%d,%d,%d)\n", i+1, // // // lr.get_index_space().get_id(), lr.get_field_space().get_id(), @@ -310,7 +307,6 @@ CostMetrics // } // } - // bool Concat::measure_operator_cost(Simulator *sim, // MachineView const &mv, // CostMetrics &cost_metrics) const { @@ -337,12 +333,14 @@ CostMetrics // (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); // out_of_memory = out_of_memory || (input_ptrs[i] == NULL); // } -// cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); +// cost_metrics.inputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); // Domain out_domain = sub_output.get_domain(); -// float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); -// GenericTensorAccessorW output_acc(DT_FLOAT, out_domain, output_ptr); -// cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); +// float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), +// DT_FLOAT); GenericTensorAccessorW output_acc(DT_FLOAT, out_domain, +// output_ptr); cost_metrics.outputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); // out_of_memory = out_of_memory || (output_ptr == NULL); // if (out_of_memory) { @@ -372,10 +370,11 @@ CostMetrics // (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); // out_of_memory = out_of_memory || (input_grad_ptrs[i] == NULL); // input_grad_accs[i] = -// GenericTensorAccessorW(DT_FLOAT, in_domains[i], input_grad_ptrs[i]); +// GenericTensorAccessorW(DT_FLOAT, in_domains[i], +// input_grad_ptrs[i]); // } -// cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); -// float *output_grad_ptr = +// cost_metrics.inputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); float *output_grad_ptr = // (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); // GenericTensorAccessorR output_grad_acc( // DT_FLOAT, out_domain, output_grad_ptr); @@ -389,7 +388,8 @@ CostMetrics // return true; // } // backward = [&](ffStream_t stream) { -// backward_kernel(stream, m, output_grad_acc, input_grad_accs, numInputs); +// backward_kernel(stream, m, output_grad_acc, input_grad_accs, +// numInputs); // }; // } @@ -397,9 +397,8 @@ CostMetrics // if (sim->computationMode == COMP_MODE_TRAINING) { // printf( -// "[Measure Concat] name(%s) forward_time(%.4lf) backward_time(%.4lf)\n", -// name, -// cost_metrics.forward_time, +// "[Measure Concat] name(%s) forward_time(%.4lf) +// backward_time(%.4lf)\n", name, cost_metrics.forward_time, // cost_metrics.backward_time); // } else { // printf("[Measure Concat] name(%s) forward_time(%.4lf)\n", @@ -408,4 +407,4 @@ CostMetrics // } // return true; -// } \ No newline at end of file +// } diff --git a/lib/runtime/src/serialization.h b/lib/runtime/src/serialization.h index 6839bfdde5..5c1194c7d6 100644 --- a/lib/runtime/src/serialization.h +++ b/lib/runtime/src/serialization.h @@ -7,10 +7,10 @@ #include "legion/legion_utilities.h" #include "op-attrs/dim_ordered.h" #include "utils/optional.h" +#include "utils/required.h" +#include "utils/type_traits.h" #include "utils/variant.h" #include "utils/visitable.h" -#include "utils/type_traits.h" -#include "utils/required.h" namespace FlexFlow { @@ -79,7 +79,8 @@ struct is_trivially_serializable< : std::true_type {}; template -struct is_trivially_serializable>> : is_trivially_serializable> {}; +struct is_trivially_serializable>> + : is_trivially_serializable> {}; template struct is_trivially_serializable> : is_trivially_serializable {}; diff --git a/lib/runtime/src/task_spec/op_task_invocation.h b/lib/runtime/src/task_spec/op_task_invocation.h index 9203f2f4f1..56e709734e 100644 --- a/lib/runtime/src/task_spec/op_task_invocation.h +++ b/lib/runtime/src/task_spec/op_task_invocation.h @@ -5,9 +5,8 @@ #include "index_task_invocation.h" #include "legion.h" #include "op_arg_ref.h" -#include "op_tensor_spec.h" -#include "variadic_tensor_ref.h" #include "op_task_signature.h" +#include "op_tensor_spec.h" #include "runtime/config.h" #include "runtime/profiling.h" #include "serialization.h" @@ -16,6 +15,7 @@ #include "utils/bidict.h" #include "utils/optional.h" #include "utils/stack_map.h" +#include "variadic_tensor_ref.h" #include #include #include diff --git a/lib/runtime/src/task_spec/op_tensor_spec.h b/lib/runtime/src/task_spec/op_tensor_spec.h index 84261141d3..d859bb3072 100644 --- a/lib/runtime/src/task_spec/op_tensor_spec.h +++ b/lib/runtime/src/task_spec/op_tensor_spec.h @@ -15,7 +15,6 @@ OpTensorSpec input_tensor(int); OpTensorSpec output_tensor(int); OpTensorSpec weight_tensor(int); - } // namespace FlexFlow #endif diff --git a/lib/runtime/src/task_spec/variadic_tensor_ref.h b/lib/runtime/src/task_spec/variadic_tensor_ref.h index 1b9bc33b4f..ddd9bd5069 100644 --- a/lib/runtime/src/task_spec/variadic_tensor_ref.h +++ b/lib/runtime/src/task_spec/variadic_tensor_ref.h @@ -1,13 +1,12 @@ #ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H #define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H -#include "op_tensor_spec.h" #include "arg_ref.h" +#include "op_tensor_spec.h" namespace FlexFlow { -enum class VariadicTensorRefType {INPUT_TENSORS, - NUM_INPUTS}; +enum class VariadicTensorRefType { INPUT_TENSORS, NUM_INPUTS }; template using VariadicTensorRef = ArgRef; @@ -20,7 +19,6 @@ VariadicTensorRef get_number_inputs() { return {VariadicTensorRefType::NUM_INPUTS}; } - } // namespace FlexFlow #endif From fdf8351bc7cdcf1c155b5a38f357dae0fe968067 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Fri, 1 Sep 2023 11:20:25 -0700 Subject: [PATCH 26/26] format --- lib/kernels/include/kernels/conv_2d_kernels.h | 22 +-- lib/runtime/src/ops/conv_2d.cc | 79 ++++---- lib/runtime/src/ops/conv_2d.h | 170 +++++++++--------- 3 files changed, 131 insertions(+), 140 deletions(-) diff --git a/lib/kernels/include/kernels/conv_2d_kernels.h b/lib/kernels/include/kernels/conv_2d_kernels.h index dc6bf941b1..75eefbe1c2 100644 --- a/lib/kernels/include/kernels/conv_2d_kernels.h +++ b/lib/kernels/include/kernels/conv_2d_kernels.h @@ -38,17 +38,17 @@ namespace Kernels { namespace Conv2D { Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle, - ffTensorDescriptor_t inputTensor, - ffTensorDescriptor_t biasTensor, - ffTensorDescriptor_t outputTensor, - ffFilterDescriptor_t filterDesc, - ffActivationDescriptor_t actiDesc, - ffConvolutionDescriptor_t convDesc, - ffConvolutionFwdAlgo_t fwdAlgo, - ffConvolutionBwdFilterAlgo_t bwdFilterAlgo, - ffConvolutionBwdDataAlgo_t bwdDataAlgo, - req> relu, - bool use_bias); + ffTensorDescriptor_t inputTensor, + ffTensorDescriptor_t biasTensor, + ffTensorDescriptor_t outputTensor, + ffFilterDescriptor_t filterDesc, + ffActivationDescriptor_t actiDesc, + ffConvolutionDescriptor_t convDesc, + ffConvolutionFwdAlgo_t fwdAlgo, + ffConvolutionBwdFilterAlgo_t bwdFilterAlgo, + ffConvolutionBwdDataAlgo_t bwdDataAlgo, + req> relu, + bool use_bias); void forward_kernel(ffStream_t stream, Conv2DPerDeviceState const *m, diff --git a/lib/runtime/src/ops/conv_2d.cc b/lib/runtime/src/ops/conv_2d.cc index fc87ed987f..8379da80ff 100644 --- a/lib/runtime/src/ops/conv_2d.cc +++ b/lib/runtime/src/ops/conv_2d.cc @@ -2,8 +2,8 @@ #include "kernels/conv_2d_kernels.h" #include "legion/legion_utilities.h" #include "mpark/variant.hpp" -#include "utils/hash-utils.h" #include "op-attrs/get_output_shapes.h" +#include "utils/hash-utils.h" namespace FlexFlow { @@ -57,10 +57,10 @@ OpTaskInvocation backward(Conv2DAttrs const &attrs) { static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { - + PerDeviceFFHandle handle = acc.get_argument(HANDLE); auto const &attrs = acc.get_argument(ATTRS); - + ffTensorDescriptor_t inputTensor; ffTensorDescriptor_t biasTensor; ffTensorDescriptor_t outputTensor; @@ -71,20 +71,20 @@ static DeviceSpecific ffConvolutionBwdFilterAlgo_t bwdFilterAlgo; ffConvolutionBwdDataAlgo_t bwdDataAlgo; - DeviceSpecific per_device_state = + DeviceSpecific per_device_state = acc.create_device_specific( - init_kernel(handle, - inputTensor, - biasTensor, - outputTensor, - filterDesc, - actiDesc, - convDesc, - fwdAlgo, - bwdFilterAlgo, - bwdDataAlgo, - attrs.activation, - attrs.use_bias)); + init_kernel(handle, + inputTensor, + biasTensor, + outputTensor, + filterDesc, + actiDesc, + convDesc, + fwdAlgo, + bwdFilterAlgo, + bwdDataAlgo, + attrs.activation, + attrs.use_bias)); return per_device_state; } @@ -100,7 +100,8 @@ static DeviceSpecific static optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); auto input = acc.get_tensor(INPUT); auto filter = acc.get_tensor(FILTER); @@ -108,13 +109,13 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); return profile(forward_kernel, - profiling, - "[Conv2d] forward_time = %.2lfms\n", - &per_device_state, - input.get_float_ptr(), - output.get_float_ptr(), - filter.get_float_ptr(), - bias.get_float_ptr()); + profiling, + "[Conv2d] forward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output.get_float_ptr(), + filter.get_float_ptr(), + bias.get_float_ptr()); } static void forward_task(Task const *task, @@ -127,7 +128,8 @@ static void forward_task(Task const *task, static optional backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); @@ -139,16 +141,16 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { auto bias_grad = acc.get_tensor_grad(BIAS); return profile(backward_kernel, - profiling, - "[Conv2d] backward_time = %.2lfms\n", - &per_device_state, - input.get_float_ptr(), - input_grad.get_float_ptr(), - output.get_float_ptr(), - output_grad.get_float_ptr(), - filter.get_float_ptr(), - filter_grad.get_float_ptr(), - bias_grad.get_float_ptr()); + profiling, + "[Conv2d] backward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + input_grad.get_float_ptr(), + output.get_float_ptr(), + output_grad.get_float_ptr(), + filter.get_float_ptr(), + filter_grad.get_float_ptr(), + bias_grad.get_float_ptr()); } static void backward_task(Task const *task, @@ -166,7 +168,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, InputParallelTensorDesc const &bias_shape, ProfilingSettings const &settings, MachineView const &mv) { - + auto env = sim.new_environment(); ParallelTensorShape output_shape = get_output_shape(attrs, input_shape.shape); @@ -175,15 +177,14 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, init_binding.bind_arg(ATTRS, attrs); init_binding.bind_arg(HANDLE, ff_handle()); - auto init_accessor = - env.get_init_accessor(CONV2D_INIT_TASK_ID, init_binding); + auto init_accessor = env.get_init_accessor(CONV2D_INIT_TASK_ID, init_binding); DeviceSpecific per_device_state = init_task_impl(init_accessor); SimTaskBinding fwd_binding; fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); - + fwd_binding.bind(INPUT, input_shape); fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind(FILTER, filter_shape); diff --git a/lib/runtime/src/ops/conv_2d.h b/lib/runtime/src/ops/conv_2d.h index ae76778670..777b491089 100644 --- a/lib/runtime/src/ops/conv_2d.h +++ b/lib/runtime/src/ops/conv_2d.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_CONV_2D_H #include "op-attrs/ops/conv_2d.h" -#include "task_spec/op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -135,8 +135,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, #endif - - // Tensor FFModel::conv2d(Tensor const &input, // int outChannels, // int kernelH, @@ -166,7 +164,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // use_bias}; // TensorShape output_shape = get_output_shape(attrs, input->get_shape()); -// Tensor output = this->tensor_mgr.create(output_shape, CreateGrad::YES, conv); +// Tensor output = this->tensor_mgr.create(output_shape, CreateGrad::YES, +// conv); // std::vector weights; @@ -181,7 +180,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // } // Layer *conv = -// this->layer_mgr.create(attrs, DT_FLOAT, name, {input}, weights, {output}); +// this->layer_mgr.create(attrs, DT_FLOAT, name, {input}, weights, +// {output}); // //{ // // int numdims = 4; @@ -231,7 +231,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // ); // } - /* void Conv2DParams::mark_replica_dims( */ /* ParallelTensorShape const &input, */ /* ParallelDim output_dims[MAX_TENSOR_DIM], */ @@ -399,8 +398,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, /* } */ /* } */ - - // Conv2D::Conv2D(FFModel &model, // Conv2D const &other, // const ParallelTensor input, @@ -468,8 +465,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, /* return is_valid; */ /* } */ - - // Conv2D::Conv2D(FFModel &model, // LayerID const &_layer_guid, // const ParallelTensor input, @@ -520,7 +515,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // &bias_ndims); // if (allocate_weights) { -// Initializer *kernel_initializer = new GlorotUniform(std::rand() /*seed*/); +// Initializer *kernel_initializer = new GlorotUniform(std::rand() +// /*seed*/); // weights[Conv2DKernel::INDEX] = // model.create_parallel_weight_legion_ordering(kernel_ndims, @@ -551,7 +547,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // assert(check_output_input_weight_parallel_dims(allocate_weights)); // } - // tl::optional Conv2D::as_dot() const { // RecordFormatter rr; // RecordFormatter r; @@ -571,7 +566,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // return rr; // } - // bool Conv2D::estimate_sync_cost(Simulator *sim, // MachineView const &view, // CostMetrics &cost_metrics) const { @@ -614,7 +608,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, regions[3](I): bias */ - // void Conv2D::forward(FFModel const &ff) { // this->execute_task(ff, CONV2D_FWD_TASK_ID, get_fwd_task_signature()); // } @@ -623,7 +616,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // this->execute_task(ff, CONV2D_bWD_TASK_ID, get_bwd_task_signature()); // } - // TaskSpec Conv2D::get_tasks_spec() const { // OpTasksSpec spec { // CONV2D_INIT_TASK_ID, @@ -695,7 +687,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, /* return spec; */ /* } */ - /* regions[0]: input regions[1]: output @@ -705,83 +696,82 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, regions[5](O): input_grad */ - // void Conv2D::init(FFModel const &ff) { // this->execute_task(ff, CONV2D_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(CONV2D_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Conv2D)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // // launcher.add_region_requirement( - // // RegionRequirement(weights[1]->part, 0/*projection id*/, - // // READ_ONLY, EXCLUSIVE, weights[1]->region)); - // // launcher.add_field(3, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // weights[0]->region_grad)); - // launcher.add_field(3, FID_DATA); - // // launcher.add_region_requirement( - // // RegionRequirement(inputs[0]->part_grad, 0/*projection id*/, - // // WRITE_ONLY, EXCLUSIVE, inputs[0]->region_grad)); - // // launcher.add_field(4, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); +// assert(check_output_input_weight_same_parallel_is()); +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_init(ff, argmap); +// IndexLauncher launcher(CONV2D_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(Conv2D)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// // launcher.add_region_requirement( +// // RegionRequirement(weights[1]->part, 0/*projection id*/, +// // READ_ONLY, EXCLUSIVE, weights[1]->region)); +// // launcher.add_field(3, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// weights[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // launcher.add_region_requirement( +// // RegionRequirement(inputs[0]->part_grad, 0/*projection id*/, +// // WRITE_ONLY, EXCLUSIVE, inputs[0]->region_grad)); +// // launcher.add_field(4, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// set_opmeta_from_futuremap(ff, fm); // } - // printf("init conv (input): n(%d) c(%d) h(%d) w(%d)\n", - // input_n, - // input_c, - // input_h, - // input_w); - // printf("init conv (output): n(%d) c(%d) h(%d) w(%d)\n", - // output_n, - // output_c, - // output_h, - // output_w); - - // printf("convDim: padding(%d %d) stride(%d %d)\n", conv->padding_h, - // conv->padding_w, conv->stride_h, conv->stride_w); - // int pad_h = - // ((output_h - 1) * attrs.stride_h + attrs.kernel_h - input_h + 1) / 2; - // int pad_w = - // ((output_w - 1) * attrs.stride_w + attrs.kernel_w - input_w + 1) / 2; - // if (pad_h != attrs.padding_h) { - // printf("Warning: changing conv_padding_h to satisfy output_h size\n"); - // } - // if (pad_w != attrs.padding_w) { - // printf("Warning: changing conv_padding_w to satisfy output_w size\n"); - // } +// printf("init conv (input): n(%d) c(%d) h(%d) w(%d)\n", +// input_n, +// input_c, +// input_h, +// input_w); +// printf("init conv (output): n(%d) c(%d) h(%d) w(%d)\n", +// output_n, +// output_c, +// output_h, +// output_w); + +// printf("convDim: padding(%d %d) stride(%d %d)\n", conv->padding_h, +// conv->padding_w, conv->stride_h, conv->stride_w); +// int pad_h = +// ((output_h - 1) * attrs.stride_h + attrs.kernel_h - input_h + 1) / 2; +// int pad_w = +// ((output_w - 1) * attrs.stride_w + attrs.kernel_w - input_w + 1) / 2; +// if (pad_h != attrs.padding_h) { +// printf("Warning: changing conv_padding_h to satisfy output_h size\n"); +// } +// if (pad_w != attrs.padding_w) { +// printf("Warning: changing conv_padding_w to satisfy output_w size\n"); +// } // size_t rid = 0; // TensorAccessorR acc_input( @@ -832,4 +822,4 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, // true /*readOutput*/); // acc_bias_grad_ptr = static_cast(acc_bias_grad.ptr); // rid++; -// } \ No newline at end of file +// }