diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 0e4437bdb8..ec32648d0f 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -6,38 +6,43 @@ namespace FlexFlow { -class BatchMatmulPerDeviceState : public PerDeviceOpState { -public: - BatchMatmulPerDeviceState(FFHandler handler); - int a_seq_length_dim, b_seq_length_dim; +struct BMMPerDeviceState { + PerDeviceFFHandle handle; + Allocator allocator; + int a_seq_length_dim; + req b_seq_length_dim; }; +FF_VISITABLE_STRUCT_NO_EQ( + BMMPerDeviceState, handle, allocator, a_seq_length_dim, b_seq_length_dim); + namespace Kernels { namespace BatchMatmul { +BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator const &allocator, + int a_seq_length_dim, + int b_seq_length_dim); + void forward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, - float *o_ptr, - float const *a_ptr, - float const *b_ptr, - float const *c_ptr, + BMMPerDeviceState const &meta, + float *output_ptr, + float const *lhs_input_ptr, + float const *rhs_input_ptr, int m, int n, int k, int batch, - int a_seq_length_dim = -1, - int b_seq_length_dim = -1, int seq_length = -1); void backward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, + BMMPerDeviceState const &meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, float *a_grad_ptr, float const *b_ptr, float *b_grad_ptr, - float *c_grad_ptr, int m, int n, int k, diff --git a/lib/kernels/include/kernels/batch_norm_kernels.h b/lib/kernels/include/kernels/batch_norm_kernels.h index 6ff90299db..74dfc96068 100644 --- a/lib/kernels/include/kernels/batch_norm_kernels.h +++ b/lib/kernels/include/kernels/batch_norm_kernels.h @@ -8,30 +8,66 @@ namespace FlexFlow { -class BatchNormPerDeviceState : public PerDeviceOpState { -public: - BatchNormPerDeviceState(FFHandler handle, - std::unique_ptr allocator, - int 
output_n, - int output_c, - int output_h, - int output_w, - bool relu, - bool profiling); - ~BatchNormPerDeviceState(void); - - ffTensorDescriptor_t inputTensor, outputTensor, biasTensor; +struct BatchNormPerDeviceState { + PerDeviceFFHandle handle; + Allocator allocator; + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffTensorDescriptor_t biasTensor; ffActivationDescriptor_t actiDesc; ffBatchNormMode_t mode; - float *runningMean, *runningVar, *saveMean, *saveVar; - bool relu; - bool profiling; - std::unique_ptr allocator; + float *runningMean; + float *runningVar; + float *saveMean; + float *saveVar; + int output_n; + int output_c; + int output_h; + int output_w; + ProfilingSettings profiling; + req relu; }; +FF_VISITABLE_STRUCT_NO_EQ(BatchNormPerDeviceState, + handle, + allocator, + inputTensor, + outputTensor, + biasTensor, + actiDesc, + mode, + runningMean, + runningVar, + saveMean, + saveVar, + output_n, + output_c, + output_h, + output_w, + profiling, + relu); + namespace Kernels { namespace BatchNorm { +BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle, + Allocator allocator, + ffTensorDescriptor_t inputTensor, + ffTensorDescriptor_t outputTensor, + ffTensorDescriptor_t biasTensor, + ffActivationDescriptor_t actiDesc, + ffBatchNormMode_t mode, + float *runningMean, + float *runningVar, + float *saveMean, + float *saveVar, + int output_n, + int output_c, + int output_h, + int output_w, + ProfilingSettings profiling, + bool relu); + void forward_kernel(ffStream_t stream, BatchNormPerDeviceState *m, float const *input_ptr, diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index 3593ac4ab2..cde0df93c0 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -18,9 +18,6 @@ namespace FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels 
{ namespace BatchMatmul { @@ -124,7 +121,7 @@ O = A * B */ void forward_kernel(cudaStream_t stream, - BatchMatmulPerDeviceState const *meta, + BMMPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp b/lib/kernels/src/hip/batch_matmul_kernels.cpp index d8b6500326..a06442d3d6 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -19,9 +19,6 @@ namespace FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -32,7 +29,7 @@ O: (batch, n, m) O = A * B */ void forward_kernel(hipStream_t stream, - BatchMatmulPerDeviceState const *meta, + BMMPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, @@ -42,9 +39,9 @@ void forward_kernel(hipStream_t stream, int k, int batch, hipStream_t stream, - int a_seq_length_dim, - int b_seq_length_dim, int seq_length) { + int a_seq_length_dim = meta.a_seq_length_dim; + int b_seq_length_dim = meta.b_seq_length_dim; checkCUDA(hipblasSetStream(meta->handle.blas, stream)); checkCUDNN(miopenSetStream(meta->handle.dnn, stream)); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index c74824570c..00c700ba20 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -12,8 +12,10 @@ struct BatchMatmulAttrs { }; FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); -CHECK_VALID_OP_ATTR(BatchMatmulAttrs); +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); +CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/batch_matmul.cc b/lib/op-attrs/src/batch_matmul.cc index 1cc8c5cfda..bd61c24737 100644 ---
a/lib/op-attrs/src/batch_matmul.cc +++ b/lib/op-attrs/src/batch_matmul.cc @@ -2,6 +2,14 @@ namespace FlexFlow { +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs) { + return attrs.a_seq_length_dim; +} + +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs) { + return attrs.b_seq_length_dim; +} + /* bool BatchMatmulAttrs::is_valid( */ /* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { */ diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index a7b8d86171..ef7e779469 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -104,13 +104,14 @@ struct FFConfig : public use_visitable_cmp { int python_data_loader_type = 2; }; -class FFIterationConfig { -public: +struct FFIterationConfig { FFIterationConfig(); void reset(); int seq_length; }; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); + enum FieldIDs { FID_DATA, }; diff --git a/lib/runtime/src/ops/attention.cc b/lib/runtime/src/ops/attention.cc index bca87bdb53..94e2b03731 100644 --- a/lib/runtime/src/ops/attention.cc +++ b/lib/runtime/src/ops/attention.cc @@ -121,18 +121,6 @@ static DeviceSpecific int num_samples = get_piece_shape(query_parallel_tensor_shape)[ff_dim_t(2)]; int num_heads = get_piece_shape(weight_parallel_tensor_shape)[ff_dim_t(1)]; - assert(qoSeqLength == query.shape[legion_dim_t(1)]); - assert(qSize == query.shape[legion_dim_t(0)]); - assert(num_samples == key.shape[legion_dim_t(2)]); - assert(kvSeqLength == key.shape[legion_dim_t(1)]); - assert(kSize == key.shape[legion_dim_t(0)]); - assert(num_samples == value.shape[legion_dim_t(2)]); - assert(kvSeqLength == value.shape[legion_dim_t(1)]); - assert(vSize == value.shape[legion_dim_t(0)]); - assert(num_samples == output.shape[legion_dim_t(2)]); - assert(qoSeqLength == output.shape[legion_dim_t(1)]); - assert(oProjSize == output.shape[legion_dim_t(0)]); - DeviceSpecific per_device_state = acc.create_device_specific( 
init_kernel(handle, @@ -149,9 +137,6 @@ static DeviceSpecific qoSeqLength, kvSeqLength, attrs.add_bias_kv)); - - assert(weight.shape.get_volume() * sizeof(float) == - acc.unwrap(per_device_state)->weightSize); return per_device_state; } diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 3e860bd413..45c5e11b9c 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -15,752 +15,268 @@ #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" -#include "kernels/profiling.h" -#include "legion/legion_utilities.h" -#include "tasks.h" +#include "legion.h" +#include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/batch_matmul.h" namespace FlexFlow { using namespace FlexFlow::Kernels::BatchMatmul; +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; + enum Slots { - A_INPUT, - B_INPUT, - OUTPUT, - A_INPUT_GRAD, - B_INPUT_GRAD, - OUTPUT_GRAD, - ATTRS, - PROFILING + A_INPUT, // tensor + B_INPUT, // tensor + OUTPUT, // tensor + PROFILING, + HANDLE, + A_SEQ_LENGTH_DIM, + B_SEQ_LENGTH_DIM, + PER_DEVICE_STATE, + ITERATION_CONFIG }; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; + OpTaskBinding init; - b.bind_arg(ATTRS, attrs); - b.bind_arg(PROFILING, enable_profiling()); + init.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init.bind_arg(HANDLE, ff_handle()); - return {BATCHMATMUL_INIT_TASK_ID, b}; + return {BATCHMATMUL_INIT_TASK_ID, init}; } OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; - - b.bind(A_INPUT, input_tensor(0)); - b.bind(B_INPUT, input_tensor(1)); - b.bind(OUTPUT, output_tensor(0)); - - return {BATCHMATMUL_FWD_TASK_ID, b}; -} - -OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return {BATCHMATMUL_BWD_TASK_ID, b}; -} - 
-BatchMatmulParams BatchMatmul::get_params() const { - BatchMatmulParams params; - params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; - params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; - return params; -} - -Tensor FFModel::batch_matmul(const Tensor A, - const Tensor B, - int a_seq_length_dim, - int b_seq_length_dim, - char const *name) { - Layer *bmm = new Layer(this, - OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B); - assert((a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - int dims[MAX_TENSOR_DIM]; - int numdim = A->num_dims; - for (int i = 0; i < numdim; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - bmm->outputs[0] = create_tensor_legion_ordering( - numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); - bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); - bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); - layers.push_back(bmm); - return bmm->outputs[0]; -} - -Op *BatchMatmul::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("a_seq_length_dim", value); - int a_seq_length_dim = value; - layer->get_int_property("b_seq_length_dim", value); - int b_seq_length_dim = value; - return new BatchMatmul(model, - inputs[0], - inputs[1], - a_seq_length_dim, - b_seq_length_dim, - layer->name); -} - -BatchMatmul::BatchMatmul( - FFModel &model, - BatchMatmulParams const ¶ms, - std::pair const &inputs, - char const *name) - : BatchMatmul(model, - inputs.first, - inputs.second, - 
params.a_seq_length_dim, - params.b_seq_length_dim, - name) {} - -// return A*B -BatchMatmul::BatchMatmul(FFModel &model, - const ParallelTensor A, - const ParallelTensor B, - int _a_seq_length_dim, - int _b_seq_length_dim, - char const *name) - : Op(model, - OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B), - a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), - b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { - assert((_a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((_b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < A->num_dims; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - numOutputs = 1; - outputs[0] = model.create_parallel_tensor_legion_ordering( - A->num_dims, dims, DT_FLOAT, this); - // C is not none - // if (C != Tensor::NO_TENSOR) { - // numInputs = 3; - // assert(C.num_dims == outputs[0].num_dims); - // for (int i = 0; i < C.num_dims; i++) - // assert(C.adim[i] == outputs[0].adim[i]); - //} -} - -void BatchMatmul::serialize(Legion::Serializer &sez) const { - BatchMatmulParams params = get_params(); - sez.serialize(params.a_seq_length_dim); - sez.serialize(params.b_seq_length_dim); -} - -using PCG::Node; -/*static*/ -Node BatchMatmul::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 2); - int a_seq_length_dim, b_seq_length_dim; - dez.deserialize(a_seq_length_dim); - dez.deserialize(b_seq_length_dim); - - BatchMatmulParams params; - params.a_seq_length_dim = a_seq_length_dim; - params.b_seq_length_dim = b_seq_length_dim; - return 
ff.get_or_create_node({inputs[0], inputs[1]}, params); -} + OpTaskBinding fwd; -Op *BatchMatmul::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - BatchMatmulParams params = get_params(); - return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); -} + fwd.bind(A_INPUT, input_tensor(0)); + fwd.bind(B_INPUT, input_tensor(1)); + fwd.bind(OUTPUT, output_tensor(0)); -template <> -void register_task() { - OpTaskSignature sig(OpTaskType::INIT); + fwd.bind_arg(PROFILING, profiling_settings()); + fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + fwd.bind_arg(ITERATION_CONFIG, iteration_config()); - sig.add_arg_slot(ATTRS); - sig.add_arg_slot(PROFILING); - - register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", sig, init_task); + return {BATCHMATMUL_FWD_TASK_ID, fwd}; } -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(A_INPUT, READ_WRITE); - fwd.add_input_slot(B_INPUT, READ_WRITE); - fwd.add_output_slot(OUTPUT); +OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { + OpTaskBinding bwd = infer_bwd_binding(forward(attrs).binding); - return fwd; + return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + int const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); + int const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + Allocator allocator = acc.get_allocator(); - bwd.add_input_slot(A_INPUT); - bwd.add_input_slot(B_INPUT); - bwd.add_input_grad_slot(A_INPUT_GRAD); - bwd.add_input_grad_slot(B_INPUT_GRAD); - bwd.add_output_slot(OUTPUT); - bwd.add_output_grad_slot(OUTPUT_GRAD); + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); - return bwd; + return 
per_device_state; } -OpTaskBinding BatchMatmul::get_init_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); - - return binding; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -OpTaskBinding BatchMatmul::get_fwd_task_binding() const { - OpTaskBinding binding; +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto a_input = acc.get_tensor(A_INPUT); + auto b_input = acc.get_tensor(B_INPUT); + auto output = acc.get_tensor(OUTPUT); - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind(OUTPUT, output_tensor(0)); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); - binding.bind_arg(ATTRS, this->attrs); - return binding; -} + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); -OpTaskBinding BatchMatmul::get_bwd_task_binding() const { - OpTaskBinding binding; - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); - binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); - binding.bind(OUTPUT, output_tensor(0)); - binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); + int batch = 1; + for (int i = 2; i < a_input.shape.get_dim(); + i++) { // get_dim() or get_volume()? 
+ int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); + batch *= dim_size; + } - binding.bind_arg(ATTRS, this->attrs); - return binding; + return profile(forward_kernel, + profiling, + "[BatchMatmul] forward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + a_input.get_float_ptr(), + b_input.get_float_ptr(), + m, + n, + k, + batch, + iter_config.seq_length); } -void BatchMatmul::init(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // init_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} -} // namespace FlexFlow -// / -// template -// void BatchMatmul::init_with_dim(FFModel const &ff) { -// assert(check_output_input_weight_same_parallel_is()); -// parallel_is = outputs[0]->parallel_is; -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_init(ff, argmap); -// IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, -// parallel_is, -// TaskArgument(this, sizeof(BatchMatmul)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// FutureMap fm = runtime->execute_index_space(ctx, launcher); -// fm.wait_all_results(); -// set_opmeta_from_futuremap(ff, fm); -// } - -PerDeviceOpState * - 
BatchMatmul::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - - FFHandler handle = *((FFHandler const *)task->local_args); - BatchMatmulPerDeviceState *m = new BatchMatmulPerDeviceState(handle); - m->profiling = profiling; - m->a_seq_length_dim = attrs.a_seq_length_dim; - m->b_seq_length_dim = attrs.b_seq_length_dim; - return m; + forward_task_impl(acc); } -void BatchMatmul::forward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // forward_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} -} +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + // BatchMatmul* bmm = (BatchMatmul*) task->args; + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); -// template -// void BatchMatmul::forward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_forward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_FWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < 
numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](O): output - regions[1](I): A - regions[2](I): B - ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? - output = A * B /////////+ C -*/ -void BatchMatmul::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + // is this equivalent to checking `Domain` equality? + assert(output.shape == output_grad.shape); - TaskArgumentAccessor acc(task, regions, ctx, runtime); + auto a_input = acc.get_tensor(A_INPUT); + auto a_input_grad = acc.get_tensor_grad(A_INPUT); + assert(a_input.shape == a_input_grad.shape); - // const BatchMatmul* bmm = (const BatchMatmul*) task->args; - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - // BatchMatmulMeta const *meta = *((BatchMatmulMeta **)task->local_args); - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - - auto a_input = acc.get_tensor(A_INPUT); - auto b_input = acc.get_tensor(B_INPUT); - auto output = acc.get_tensor(OUTPUT); - - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); - - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); + auto b_input = acc.get_tensor(B_INPUT); + auto b_input_grad = acc.get_tensor_grad(B_INPUT); + assert(b_input.shape == b_input_grad.shape); + + // check dins + int m = b_input.shape[legion_dim_t(0)]; + assert(m == 
output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); + for (int i = 2; i < a_input.shape.dims.num_dims(); i++) { + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); batch *= dim_size; } - float *out_ptr = output.get_float_ptr(); - c float const *a_ptr = a_input.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float const *c_ptr = NULL; - // if (regions.size() == 4) { - // Domain c_domain = runtime->get_index_space_domain( - // ctx, task->regions[3].region.get_index_space()); - // assert(c_domain == a_domain); - // c_ptr = helperGetTensorPointerRO( - // regions[3], task->regions[3], FID_DATA, ctx, runtime); - // } - - profile(forward_kernel, - meta->profiling, - "[BatchMatmul] forward_time = %.2lfms\n", - out_ptr, - a_ptr, - b_ptr, - c_ptr, - m, - n, - k, - batch, - meta->a_seq_length_dim, - meta->b_seq_length_dim, - iter_config->seq_length); -} -void BatchMatmul::backward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - backward_with_dim(ff); \ - break; \ - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); - } + return profile(backward_kernel, + profiling, + "[BatchMatmul] backward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + a_input.get_float_ptr(), + a_input_grad.get_float_ptr(), + b_input.get_float_ptr(), + b_input_grad.get_float_ptr(), + m, 
+ n, + k, + batch); } -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -// template -// void BatchMatmul::backward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_backward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_BWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// // regions[0](I): output -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// // regions[1](I): output_grad -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region_grad)); -// launcher.add_field(1, FID_DATA); -// // regions[2](I): A -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[0]->region)); -// launcher.add_field(2, FID_DATA); -// // regions[3](I/O): A_grad -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[0]->region_grad)); -// launcher.add_field(3, FID_DATA); -// // regions[4](I): B -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[1]->region)); -// launcher.add_field(4, FID_DATA); -// // regions[5](I/O): B_grad -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[1]->region_grad)); -// launcher.add_field(5, 
FID_DATA); -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -__host__ void - BatchMatmul::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - // Currently assume C is NULL - assert(regions.size() == 6); - assert(task->regions.size() == 6); - // BatchMatmul* bmm = (BatchMatmul*) task->args; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - // output domains - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor(OUTPUT_GRAD); - assert(output == - output_grad); // is this equivalent to checking `Domain` equality? 
- // A domains - auto a_input = acc.get_tensor(A_INPUT); - auto a_input_grad = acc.get_tensor(A_INPUT_GRAD); - assert(a_input == a_input_grad); - // B domains - auto b_input = acc.get_tensor(B_INPUT); - auto b_input_grad = acc.get_tensor(B_INPUT_GRAD); - assert(b_input == b_input_grad); - - // check dins - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); - int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); - batch *= dim_size; - } - // get pointers - float const *out_ptr = output.get_float_ptr(); - float const *out_grad_ptr = output_grad.get_float_ptr(); - float const *a_ptr = a_input.get_float_ptr(); - float *a_grad_ptr = a_input_grad.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float *b_grad_ptr = b_input_grad.get_float_ptr(); - - float *c_grad_ptr = NULL; - - // TODO: add support for meta->a_seq_length_dim >= 0 - // or meta->b_seq_length_dim >= 0 - assert((meta->a_seq_length_dim >= a_len) || (iter_config->seq_length == 0)); - assert((meta->b_seq_length_dim >= b_len) || (iter_config->seq_length == 0)); - - profile(backward_kernel, - meta->profiling, - "[BatchMatmul] backward_time = %.2lfms\n", - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, - m, - n, - k, - batch); + backward_task_impl(acc); } -void BatchMatmul::print_layer(FFModel const &ff) { - return; +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc) { + auto env = sim.new_environment(); + + 
ParallelTensorShape output_shape = + get_output_shape(attrs, a_input.shape, b_input.shape); + + SimTaskBinding init_binding; + init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init_binding.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init_binding.bind_arg(HANDLE, ff_handle()); + + auto init_accessor = + env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(A_INPUT, a_input); + fwd_binding.bind(B_INPUT, b_input); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = + env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -bool BatchMatmul::measure_operator_cost(Simulator *sim, - MachineView const &pc, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input0, sub_input1; - if (!outputs[0]->get_sub_tensor(pc, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(pc, sub_input0)) { - return false; - } - if (!inputs[1]->get_sub_tensor(pc, sub_input1)) { - return false; - } - - int input0_c = sub_input0.dims[0].size; - int input0_r = sub_input0.dims[1].size; - int input1_c = sub_input1.dims[0].size; - int input1_r = sub_input1.dims[1].size; - int output_c = sub_output.dims[0].size; - int output_r = sub_output.dims[1].size; +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); - assert(input0_c == input1_r); - assert(input0_r == 
output_r); - assert(input1_c == output_c); + init.add_arg_slot(A_SEQ_LENGTH_DIM); + init.add_arg_slot(B_SEQ_LENGTH_DIM); + init.add_unchecked_arg_slot(HANDLE); - assert(sub_input0.dims[2] == sub_input1.dims[2]); - assert(sub_input1.dims[2] == sub_output.dims[2]); - int batch = 1; - assert(sub_input0.num_dims == sub_input1.num_dims); - for (int i = 2; i < sub_input0.num_dims; i++) { - assert(sub_input0.dims[i] == sub_input1.dims[i]); - assert(sub_input0.dims[i] == sub_output.dims[i]); - batch *= sub_input0.dims[i].size; - } + register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); +} - BatchMatmulPerDeviceState *meta = sim->batch_matmul_meta; - - // allocate tensors in simulator - sim->free_all(); - float *a_ptr = (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - assert(a_ptr != NULL); - float *b_ptr = (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - assert(b_ptr != NULL); - float *c_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *out_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - int m = input1_c; - int n = input0_r; - int k = input0_c; - - assert(meta->profiling == false); - - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, meta, out_ptr, a_ptr, b_ptr, c_ptr, m, n, k, batch); - }; - - if (sim->computationMode == COMP_MODE_TRAINING) { - float *a_grad_ptr = - (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - float *b_grad_ptr = - (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - float *c_grad_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *out_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_grad_ptr != NULL); - cost_metrics.outputs_memory += - 
cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, - m, - n, - k, - batch); - }; - } +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf) backward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time); - } + fwd.add_input_slot(A_INPUT); + fwd.add_input_slot(B_INPUT); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - return true; + register_task(BATCHMATMUL_FWD_TASK_ID, "BatchMatmul Fwd", fwd, forward_task); } + +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(ATTENTION_FWD_TASK_ID)); + + register_task(BATCHMATMUL_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); } -; // namespace FlexFlow + +}; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index c133c2a875..018fe1d582 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -1,10 +1,14 @@ #ifndef _FLEXFLOW_BATCH_MATMUL_H #define _FLEXFLOW_BATCH_MATMUL_H +// #include "op-attrs/ops/batch_matmul.h" +// #include "task_spec/op_task_invocation.h" +// #include "task_spec/op_task_signature.h" +// #include "sim_environment.h" 
+ #include "op-attrs/ops/batch_matmul.h" -#include "op_task_invocation.h" -#include "op_task_signature.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -19,68 +23,426 @@ OpTaskInvocation init(BatchMatmulAttrs const &); OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, +CostMetrics measure_operator_cost(SimEnvFactory const &sim, BatchMatmulAttrs const &attrs, - ParallelTensorShape const &lhs_input_shape, - ParallelTensorShape const &rhs_input_shape, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, ProfilingSettings const &settings, - MachineView const &); - -/* class BatchMatmul : public Op { */ -/* public: */ -/* BatchMatmul(FFModel &model, */ -/* const ParallelTensor A, */ -/* const ParallelTensor B, */ -/* int a_seq_length_dim, */ -/* int b_seq_length_dim, */ -/* char const *name = nullptr); */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* /1* static PCG::Node deserialize(FFModel &ff, *1/ */ -/* /1* Legion::Deserializer &d, *1/ */ -/* /1* ParallelTensor inputs[], *1/ */ -/* /1* int num_inputs); *1/ */ -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* 
MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* OpTaskBinding get_init_task_binding() const override; */ -/* OpTaskBinding get_fwd_task_binding() const override; */ -/* OpTaskBinding get_bwd_task_binding() const override; */ -/* private: */ -/* template */ -/* void init_with_dim(FFModel const &ff); */ -/* template */ -/* void forward_with_dim(FFModel const &ff); */ -/* template */ -/* void backward_with_dim(FFModel const &ff); */ - -/* public: */ -/* int a_seq_length_dim, b_seq_length_dim; */ -/* }; */ + MachineView const &pc); } // namespace FlexFlow #endif + +// BatchMatmulParams BatchMatmul::get_params() const { +// BatchMatmulParams params; +// params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; +// params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; +// return params; +// } + +// Tensor FFModel::batch_matmul(const Tensor A, +// const Tensor B, +// int a_seq_length_dim, +// int b_seq_length_dim, +// char const *name) { +// Layer *bmm = new Layer(this, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B); +// assert((a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// int dims[MAX_TENSOR_DIM]; +// int numdim = A->num_dims; +// for (int i = 0; i < numdim; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// bmm->outputs[0] = create_tensor_legion_ordering( +// numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); +// bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); +// bmm->add_int_property("b_seq_length_dim", 
b_seq_length_dim); +// layers.push_back(bmm); +// return bmm->outputs[0]; +// } + +// Op *BatchMatmul::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("a_seq_length_dim", value); +// int a_seq_length_dim = value; +// layer->get_int_property("b_seq_length_dim", value); +// int b_seq_length_dim = value; +// return new BatchMatmul(model, +// inputs[0], +// inputs[1], +// a_seq_length_dim, +// b_seq_length_dim, +// layer->name); +// } + +// BatchMatmul::BatchMatmul( +// FFModel &model, +// BatchMatmulParams const ¶ms, +// std::pair const &inputs, +// char const *name) +// : BatchMatmul(model, +// inputs.first, +// inputs.second, +// params.a_seq_length_dim, +// params.b_seq_length_dim, +// name) {} + +// // return A*B +// BatchMatmul::BatchMatmul(FFModel &model, +// const ParallelTensor A, +// const ParallelTensor B, +// int _a_seq_length_dim, +// int _b_seq_length_dim, +// char const *name) +// : Op(model, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B), +// a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), +// b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { +// assert((_a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((_b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < A->num_dims; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// numOutputs = 1; +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// A->num_dims, dims, DT_FLOAT, this); +// // C is not none +// // if (C != 
Tensor::NO_TENSOR) { +// // numInputs = 3; +// // assert(C.num_dims == outputs[0].num_dims); +// // for (int i = 0; i < C.num_dims; i++) +// // assert(C.adim[i] == outputs[0].adim[i]); +// //} +// } + +// void BatchMatmul::serialize(Legion::Serializer &sez) const { +// BatchMatmulParams params = get_params(); +// sez.serialize(params.a_seq_length_dim); +// sez.serialize(params.b_seq_length_dim); +// } + +// using PCG::Node; +// /*static*/ +// Node BatchMatmul::deserialize(FFModel &ff, +// Legion::Deserializer &dez, +// ParallelTensor inputs[], +// int num_inputs) { +// assert(num_inputs == 2); +// int a_seq_length_dim, b_seq_length_dim; +// dez.deserialize(a_seq_length_dim); +// dez.deserialize(b_seq_length_dim); + +// BatchMatmulParams params; +// params.a_seq_length_dim = a_seq_length_dim; +// params.b_seq_length_dim = b_seq_length_dim; +// return ff.get_or_create_node({inputs[0], inputs[1]}, params); +// } + +// Op *BatchMatmul::materialize(FFModel &ff, +// ParallelTensor inputs[], +// int num_inputs) const { +// BatchMatmulParams params = get_params(); +// return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); +// } + +// void BatchMatmul::forward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // forward_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, +// get_fwd_task_signature()); break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + +// template +// void BatchMatmul::forward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_FWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); 
+// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// for (int i = 0; i < numInputs; i++) { +// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[i]->region)); +// launcher.add_field(i + 1, FID_DATA); +// } +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](O): output + regions[1](I): A + regions[2](I): B + ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? + output = A * B /////////+ C +*/ + +// void BatchMatmul::init(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // init_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, +// get_init_task_signature()); break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } // namespace FlexFlow +// // / +// // template +// // void BatchMatmul::init_with_dim(FFModel const &ff) { +// // assert(check_output_input_weight_same_parallel_is()); +// // parallel_is = outputs[0]->parallel_is; +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_init(ff, argmap); +// // IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(BatchMatmul)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 
/*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // FutureMap fm = runtime->execute_index_space(ctx, launcher); +// // fm.wait_all_results(); +// // set_opmeta_from_futuremap(ff, fm); +// // } + +// OpTaskBinding BatchMatmul::get_bwd_task_binding() const { +// OpTaskBinding binding; +// binding.bind(A_INPUT, input_tensor(0)); +// binding.bind(B_INPUT, input_tensor(1)); +// binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); +// binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); + +// binding.bind(OUTPUT, output_tensor(0)); +// binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + +// static OpTaskSignature get_fwd_task_signature() { +// OpTaskSignature fwd(OpTaskType::FWD); + +// fwd.add_input_slot(A_INPUT, READ_WRITE); +// fwd.add_input_slot(B_INPUT, READ_WRITE); +// fwd.add_output_slot(OUTPUT); + +// return fwd; +// } + +// static OpTaskSignature get_bwd_task_signature() { +// OpTaskSignature bwd(OpTaskType::BWD); + +// bwd.add_input_slot(A_INPUT); +// bwd.add_input_slot(B_INPUT); +// bwd.add_input_grad_slot(A_INPUT_GRAD); +// bwd.add_input_grad_slot(B_INPUT_GRAD); +// bwd.add_output_slot(OUTPUT); +// bwd.add_output_grad_slot(OUTPUT_GRAD); + +// return bwd; +// } + +// OpTaskBinding BatchMatmul::get_init_task_binding() const { +// OpTaskBinding binding; + +// binding.bind_arg(ATTRS, this->attrs); +// binding.bind_arg(PROFILING, this->profiling); + +// return binding; +// } + +// OpTaskBinding BatchMatmul::get_fwd_task_binding() const { +// OpTaskBinding binding; + +// binding.bind(A_INPUT, input_tensor(0)); +// binding.bind(B_INPUT, input_tensor(1)); +// binding.bind(OUTPUT, output_tensor(0)); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + +// void BatchMatmul::backward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #d ef ine 
DIMFUNC(DIM) \ +// case DIM: { \ +// backward_with_dim(ff); \ +// break; \ +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + +// void BatchMatmul::print_layer(FFModel const &ff) { +// return; +// } + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + regions[5](I/O): B_grad + regions[6](I/O): C_grad +*/ +// template +// void BatchMatmul::backward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_BWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// // regions[0](I): output +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// // regions[1](I): output_grad +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// // regions[2](I): A +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(2, FID_DATA); +// // regions[3](I/O): A_grad +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // regions[4](I): B +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[1]->region)); +// launcher.add_field(4, 
FID_DATA); +// // regions[5](I/O): B_grad +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[1]->region_grad)); +// launcher.add_field(5, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + regions[5](I/O): B_grad + regions[6](I/O): C_grad +*/ diff --git a/lib/runtime/src/ops/batch_norm.cc b/lib/runtime/src/ops/batch_norm.cc index 98cc4576a1..6ebf359051 100644 --- a/lib/runtime/src/ops/batch_norm.cc +++ b/lib/runtime/src/ops/batch_norm.cc @@ -16,505 +16,285 @@ #include "batch_norm.h" #include "kernels/batch_norm_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" + +namespace FlexFlow { using namespace FlexFlow::Kernels::BatchNorm; -namespace FlexFlow { +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; enum Slots { - INPUT, - SCALE, - BIAS, - OUTPUT, - INPUT_GRAD, - SCALE_GRAD, - BIAS_GRAD, - OUTPUT_GRAD, + INPUT, // tensor + SCALE, // tensor + BIAS, // tensor + OUTPUT, // tensor ATTRS, - PROFILING -} - -Tensor - FFModel::batch_norm(const Tensor input, bool relu, char const *name) { - assert(input->num_dims == 4); /*NCHW*/ - Layer *bm = new Layer(this, - OP_BATCHNORM, - DT_FLOAT, - name, - 1 /*inputs*/, - 2 /*weights*/, - 1 /*outputs*/, - input); - int numdims = 4; - bm->outputs[0] = create_tensor_legion_ordering( - numdims, input->dims, DT_FLOAT, bm, 0, true /*create_grad*/); - bm->add_int_property("relu", relu); - layers.push_back(bm); - return bm->outputs[0]; -} - -/* - locals[0] = scale - locals[1] = bias -*/ -BatchNorm::BatchNorm(FFModel &model, - const ParallelTensor _input, - const ParallelTensor _scale, - const ParallelTensor _bias, - bool _relu, - char const *name) - : Op(model, - OP_BATCHNORM, - DT_FLOAT, - name, - 1 /*inputs*/, - 2 /*weights*/, - 1 
/*outputs*/, - _input, - _scale, - _bias), - relu(_relu) { - assert(_input->num_dims == 4); - numOutputs = 1; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < _input->num_dims; i++) { - dims[i] = _input->dims[_input->num_dims - 1 - i]; - } - outputs[0] = - model.create_parallel_tensor(_input->num_dims, dims, DT_FLOAT, this); - return; -} - -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - - // init.add_input_slot(INPUT); - // init.add_param_slot(SCALE); - // init.add_param_slot(BIAS); - init.add_output_slot(OUTPUT); -} - -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_param_slot(SCALE); - fwd.add_param_slot(BIAS); - fwd.add_output_slot(OUTPUT, WRITE_DISCARD); - - return fwd; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_arg_slot(ATTRS); + PROFILING, + PER_DEVICE_STATE, + RELU, + HANDLE +}; - bwd.add_input_slot(INPUT); - bwd.add_input_grad_slot(INPUT_GRAD, READ_WRITE); - bwd.add_param_slot(SCALE); - bwd.add_param_grad_slot(SCALE_GRAD, READ_WRITE); - bwd.add_param_grad_slot(BIAS_GRAD, READ_WRITE); - bwd.add_output_grad_slot(OUTPUT_GRAD); - - return bwd; -} - -OpTaskBinding BatchNorm::get_init_task_binding() const { +OpTaskInvocation init(BatchNormAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); - - // binding.bind(INPUT, input_tensor(0)); - // binding.bind(SCALE, param_tensor(0)); - // binding.bind(BIAS, param_tensor(1)); + binding.bind(INPUT, input_tensor(0)); + binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + binding.bind_arg(ATTRS, attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(HANDLE, ff_handle()); + + return 
{BATCHNORM_INIT_TASK_ID, binding}; } -OpTaskBinding BatchNorm::get_fwd_task_binding() const { +OpTaskInvocation forward(BatchNormAttrs const &attrs) { OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); binding.bind(INPUT, input_tensor(0)); - binding.bind(SCALE, param_tensor(0)); - binding.bind(BIAS, param_tensor(1)); + binding.bind(SCALE, input_tensor(1)); + binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {BATCHNORM_FWD_TASK_ID, binding}; } -OpTaskBinding BatchNorm::get_bwd_task_binding() const { - OpTaskBinding binding; +OpTaskInvocation backward(BatchNormAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - binding.bind_arg(ATTRS, this->attrs); + return {BATCHNORM_BWD_TASK_ID, binding}; +} - binding.bind(INPUT, input_tensor(0)); - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(SCALE, param_tensor(0)); - binding.bind(SCALE_GRAD, param_tensor(0).grad()); - binding.bind(BIAS_GRAD, param_tensor(1).grad()); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + Allocator allocator = acc.get_allocator(); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto output = acc.get_tensor(OUTPUT); + auto const &attrs = acc.get_argument(ATTRS); + + int output_w = output.shape[legion_dim_t(0)]; + int output_h = output.shape[legion_dim_t(1)]; + int output_c = output.shape[legion_dim_t(2)]; + int output_n = output.shape[legion_dim_t(3)]; + + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffTensorDescriptor_t biasTensor; + ffActivationDescriptor_t actiDesc; + ffBatchNormMode_t mode; + + size_t totalSize = sizeof(float) * output_c * 4; + float *runningMean = (float 
*)allocator.allocate(totalSize); + float *runningVar = (float *)runningMean + output_c; + float *saveMean = (float *)runningVar + output_c; + float *saveVar = (float *)saveMean + output_c; + + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, + allocator, + inputTensor, + outputTensor, + biasTensor, + actiDesc, + mode, + runningMean, + runningVar, + saveMean, + saveVar, + output_n, + output_c, + output_h, + output_w, + profiling, + attrs.relu)); + + return per_device_state; +} - return binding; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -void BatchNorm::init(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(BATCHNORM_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(BatchNorm)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // 
launcher.add_region_requirement(RegionRequirement(weights[1]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[1]->region)); - // launcher.add_field(3, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto scale = acc.get_tensor(SCALE); + auto bias = acc.get_tensor(SCALE); + + return profile(forward_kernel, + profiling, + "[BatchNorm] forward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output.get_float_ptr(), + scale.get_float_ptr(), + bias.get_float_ptr()); } -/* - regions[0]: input - regions[1]: output - regions[2](I): scale - regions[3](I): bias -*/ -PerDeviceOpState * - BatchNorm::init_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); TaskArgumentAccessor acc(task, regions, ctx, runtime); - FFHandler handle = *((FFHandler const *)task->local_args); - - auto output = acc.get_tensor(OUTPUT); - - int output_w = output.shape[0]; - int output_h = output.shape[1]; - int output_c = output.shape[2]; - int output_n = output.shape[3]; - - Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) - .only_kind(Memory::GPU_FB_MEM) - .best_affinity_to(task->target_proc) - .first(); - BatchNormPerDeviceState *m = new BatchNormPerDeviceState( - handle, bm, gpu_mem, output_n, output_c, output_h, output_w); - return m; + forward_task_impl(acc); } -void BatchNorm::forward(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = 
ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(BATCHNORM_FWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, 0), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_DISCARD, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[1]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[1]->region)); - // launcher.add_field(3, FID_DATA); - - // runtime->execute_index_space(ctx, launcher); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + auto scale = acc.get_tensor(SCALE); + auto scale_grad = acc.get_tensor_grad(SCALE); + auto bias_grad = acc.get_tensor_grad(BIAS); + + return profile(backward_kernel, + profiling, + "[BatchNorm] backward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output_grad.get_float_ptr(), + output.get_float_ptr(), + input_grad.get_float_ptr(), + scale.get_float_ptr(), + scale_grad.get_float_ptr(), + bias_grad.get_float_ptr(), + 
output.shape.get_volume()); } -/* - regions[0](I): input - regions[1](O): ouptut - regions[2](I): scale - regions[3](I): bias -*/ -void BatchNorm::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - // const BatchNorm* bm = (BatchNorm*) task->args; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - BatchNormPerDeviceState *m = *((BatchNormPerDeviceState **)task->local_args); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto scale = acc.get_tensor(SCALE); - auto bias = acc.get_tensor(SCALE); - - profile(forward_kernel, - m->profiling, - "[BatchNorm] forward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output.get_float_ptr(), - scale.get_float_ptr(), - bias.get_float_ptr()); + backward_task_impl(acc); } -void BatchNorm::backward(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(BATCHNORM_BWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, 0), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // // regions[0](I): input - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // // regions[1](I/O): input_grad (we only need grad tensors) - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // inputs[0]->region_grad)); - // launcher.add_field(1, FID_DATA); - // // regions[2](I): output - 
// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(2, FID_DATA); - // // regions[3](I/O): output_grad - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(3, FID_DATA); - // // regions[4](I): filter - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(4, FID_DATA); - // // regions[5](I/O): filter_grad - // launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // weights[0]->region_grad)); - // launcher.add_field(5, FID_DATA); - // // regions[6](I/O): bias_grad - // launcher.add_region_requirement(RegionRequirement(weights[1]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // weights[1]->region_grad)); - // launcher.add_field(6, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchNormAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + InputParallelTensorDesc const &scale_shape, + InputParallelTensorDesc const &bias_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + + // int output_w = sub_output.dims[0].size; + // int output_h = sub_output.dims[1].size; + // int output_c = sub_output.dims[2].size; + // int output_n = sub_output.dims[3].size; + // BatchNormPerDeviceState *m = new BatchNormPerDeviceState( + // sim->handler, this, sim->memory, output_n, output_c, output_h, + // output_w); + + // sim->free_all(); + // float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), + // DT_FLOAT); assert(input_ptr != NULL); 
cost_metrics.inputs_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + // float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), + // DT_FLOAT); assert(output_ptr != NULL); cost_metrics.outputs_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + // float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); + // assert(bias_ptr != NULL); + // float *scale_ptr = (float *)sim->allocate(output_c, DT_FLOAT); + // assert(scale_ptr != NULL); + // cost_metrics.weights_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + auto env = sim.new_environment(); + + ParallelTensorShape output_shape = get_output_shape(attrs); + + SimTaskBinding init_binding; + init_binding.bind(INPUT, input_shape); + init_binding.bind(BIAS, bias_shape); + init_binding.bind(OUTPUT, output_shape); + + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(PROFILING, settings); + init_binding.bind_arg(HANDLE, ff_handle()); + + auto init_accessor = + env.get_init_accessor(ATTENTION_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(SCALE, scale_shape); + fwd_binding.bind(BIAS, bias_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = env.get_fwd_accessor(ATTENTION_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(ATTENTION_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -/* - regions[0](I): input - regions[1](I/O): input_grad - regions[2](I): output - 
regions[3](I/O): output_grad - regions[4](I): scale - regions[5](I/O): scale_grad - regions[6](I/O): bias_grad -*/ -__host__ void - BatchNorm::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 7); - assert(task->regions.size() == 7); - // float beta = 0.0f; - // const BatchNorm* bm = (BatchNorm*) task->args; - TaskArgumentAccessor acc(task, regions, ctx, runtime); - BatchNormPerDeviceState *m = *((BatchNormPerDeviceState **)task->local_args); - - auto input = acc.get_tensor(INPUT); - auto input_grad = acc.get_tensor_grad(INPUT_GRAD); - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT_GRAD); - auto scale = acc.get_tensor(SCALE); - auto scale_grad = acc.get_tensor_grad(SCALE_GRAD); - auto bias_grad = acc.get_tensor_grad(BIAS_GRAD); - - profile(backward_kernel, - m->profiling, - "[BatchNorm] backward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output_grad.get_float_ptr(), - output.get_float_ptr(), - input_grad.get_float_ptr(), - scale.get_float_ptr(), - scale_grad.get_float_ptr(), - bias_grad.get_float_ptr(), - output.get_volume()); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + init.add_input_slot(INPUT); + init.add_input_slot(BIAS); + init.add_output_slot(OUTPUT); + init.add_arg_slot(ATTRS); + init.add_arg_slot(PROFILING); + init.add_unchecked_arg_slot(HANDLE); + + register_task(BATCHNORM_INIT_TASK_ID, "BatchNorm Init", init, init_task); +} + +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_input_slot(INPUT); + fwd.add_input_slot(SCALE); + fwd.add_input_slot(BIAS); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + + register_task(BATCHNORM_FWD_TASK_ID, "BatchNorm Fwd", fwd, forward_task); } -bool BatchNorm::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - 
ParallelTensorBase sub_input, sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - - int output_w = sub_output.dims[0].size; - int output_h = sub_output.dims[1].size; - int output_c = sub_output.dims[2].size; - int output_n = sub_output.dims[3].size; - BatchNormPerDeviceState *m = new BatchNormPerDeviceState( - sim->handler, this, sim->memory, output_n, output_c, output_h, output_w); - - sim->free_all(); - float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_ptr != NULL); - float *scale_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(scale_ptr != NULL); - cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset); - - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, m, input_ptr, output_ptr, scale_ptr, bias_ptr); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - float *input_grad_ptr = - (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_grad_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - float *scale_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(scale_grad_ptr != NULL); - float *bias_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_grad_ptr != NULL); - 
cost_metrics.weights_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - m, - input_ptr, - output_grad_ptr, - output_ptr, - input_grad_ptr, - scale_ptr, - scale_grad_ptr, - bias_grad_ptr, - sub_output.get_volume()); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchNorm] name(%s) size(%zu) forward_time(%.4lf) " - "backward_time(%.4lf)\n", - name, - sub_input.get_volume(), - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchNorm] name(%s) size(%zu) forward_time(%.4lf)\n", - name, - sub_input.get_volume(), - cost_metrics.forward_time); - } - // Free batchnormmeta - delete m; - return true; +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(BATCHNORM_FWD_TASK_ID)); + + register_task(BATCHNORM_BWD_TASK_ID, "BatchNorm Bwd", bwd, backward_task); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_norm.h b/lib/runtime/src/ops/batch_norm.h index e54331665e..94bda5122b 100644 --- a/lib/runtime/src/ops/batch_norm.h +++ b/lib/runtime/src/ops/batch_norm.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_BATCH_NORM_H #include "op-attrs/ops/batch_norm.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -66,3 +66,243 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + +// void BatchNorm::init(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); +// } + +// void BatchNorm::forward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); +// } + +// void BatchNorm::backward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); +// } + +// Tensor 
batch_norm(const Tensor input, bool relu, char const *name) { +// assert(input->num_dims == 4); /*NCHW*/ +// Layer *bm = new Layer(this, +// OP_BATCHNORM, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// 2 /*weights*/, +// 1 /*outputs*/, +// input); +// int numdims = 4; +// bm->outputs[0] = create_tensor_legion_ordering( +// numdims, input->dims, DT_FLOAT, bm, 0, true /*create_grad*/); +// bm->add_int_property("relu", relu); +// layers.push_back(bm); +// return bm->outputs[0]; +// } + +// BatchNorm::BatchNorm(FFModel &model, +// const ParallelTensor _input, +// const ParallelTensor _scale, +// const ParallelTensor _bias, +// bool _relu, +// char const *name) +// : Op(model, +// OP_BATCHNORM, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// 2 /*weights*/, +// 1 /*outputs*/, +// _input, +// _scale, +// _bias), +// relu(_relu) { +// assert(_input->num_dims == 4); +// numOutputs = 1; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < _input->num_dims; i++) { +// dims[i] = _input->dims[_input->num_dims - 1 - i]; +// } +// outputs[0] = +// model.create_parallel_tensor(_input->num_dims, dims, DT_FLOAT, this); +// return; +// } + +/* + locals[0] = scale + locals[1] = bias +*/ + +// void BatchNorm::init(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); +// assert(check_output_input_weight_same_parallel_is()); +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_init(ff, argmap); +// IndexLauncher launcher(BATCHNORM_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(BatchNorm)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// 
launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[1]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[1]->region)); +// launcher.add_field(3, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// set_opmeta_from_futuremap(ff, fm); +// } + +/* + regions[0]: input + regions[1]: output + regions[2](I): scale + regions[3](I): bias +*/ + +// void BatchNorm::forward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher(BATCHNORM_FWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, 0), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_DISCARD, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[1]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// 
weights[1]->region)); +// launcher.add_field(3, FID_DATA); + +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): input + regions[1](O): ouptut + regions[2](I): scale + regions[3](I): bias +*/ + +// void BatchNorm::backward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher(BATCHNORM_BWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, 0), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// // regions[0](I): input +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// // regions[1](I/O): input_grad (we only need grad tensors) +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// // regions[2](I): output +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(2, FID_DATA); +// // regions[3](I/O): output_grad +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // regions[4](I): filter +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(4, FID_DATA); +// // regions[5](I/O): filter_grad +// launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, +// 0 /*projection id*/, 
+// READ_WRITE, +// EXCLUSIVE, +// weights[0]->region_grad)); +// launcher.add_field(5, FID_DATA); +// // regions[6](I/O): bias_grad +// launcher.add_region_requirement(RegionRequirement(weights[1]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// weights[1]->region_grad)); +// launcher.add_field(6, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): input + regions[1](I/O): input_grad + regions[2](I): output + regions[3](I/O): output_grad + regions[4](I): scale + regions[5](I/O): scale_grad + regions[6](I/O): bias_grad +*/ diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.h b/lib/runtime/src/task_spec/runtime_arg_ref.h index 6b4345091a..033c2bcfbc 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.h +++ b/lib/runtime/src/task_spec/runtime_arg_ref.h @@ -3,6 +3,7 @@ #include "arg_ref.h" #include "device_specific.h" +#include "runtime/config.h" namespace FlexFlow { @@ -15,6 +16,7 @@ using RuntimeArgRefSpec = ArgRefSpec; RuntimeArgRef profiling_settings(); RuntimeArgRef> ff_handle(); +RuntimeArgRef iteration_config(); } // namespace FlexFlow