diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 0e4437bdb8..ec32648d0f 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -6,38 +6,43 @@ namespace FlexFlow { -class BatchMatmulPerDeviceState : public PerDeviceOpState { -public: - BatchMatmulPerDeviceState(FFHandler handler); - int a_seq_length_dim, b_seq_length_dim; +struct BMMPerDeviceState { + PerDeviceFFHandle handle; + Allocator allocator; + int a_seq_length_dim; + req b_seq_length_dim; }; +FF_VISITABLE_STRUCT_NO_EQ( + BMMPerDeviceState, handle, allocator, a_seq_length_dim, b_seq_length_dim); + namespace Kernels { namespace BatchMatmul { +BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator const &allocator, + int a_seq_length_dim, + int b_seq_length_dim); + void forward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, - float *o_ptr, - float const *a_ptr, - float const *b_ptr, - float const *c_ptr, + BMMPerDeviceState const &meta, + float *output_ptr, + float const *lhs_input_ptr, + float const *rhs_input_ptr, int m, int n, int k, int batch, - int a_seq_length_dim = -1, - int b_seq_length_dim = -1, int seq_length = -1); void backward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, + BMMPerDeviceState const &meta, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, float *a_grad_ptr, float const *b_ptr, float *b_grad_ptr, - float *c_grad_ptr, int m, int n, int k, diff --git a/lib/kernels/include/kernels/batch_norm_kernels.h b/lib/kernels/include/kernels/batch_norm_kernels.h index 6ff90299db..74dfc96068 100644 --- a/lib/kernels/include/kernels/batch_norm_kernels.h +++ b/lib/kernels/include/kernels/batch_norm_kernels.h @@ -8,30 +8,66 @@ namespace FlexFlow { -class BatchNormPerDeviceState : public PerDeviceOpState { -public: - BatchNormPerDeviceState(FFHandler handle, - std::unique_ptr allocator, - int 
output_n, - int output_c, - int output_h, - int output_w, - bool relu, - bool profiling); - ~BatchNormPerDeviceState(void); - - ffTensorDescriptor_t inputTensor, outputTensor, biasTensor; +struct BatchNormPerDeviceState { + PerDeviceFFHandle handle; + Allocator allocator; + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffTensorDescriptor_t biasTensor; ffActivationDescriptor_t actiDesc; ffBatchNormMode_t mode; - float *runningMean, *runningVar, *saveMean, *saveVar; - bool relu; - bool profiling; - std::unique_ptr allocator; + float *runningMean; + float *runningVar; + float *saveMean; + float *saveVar; + int output_n; + int output_c; + int output_h; + int output_w; + ProfilingSettings profiling; + req relu; }; +FF_VISITABLE_STRUCT_NO_EQ(BatchNormPerDeviceState, + handle, + allocator, + inputTensor, + outputTensor, + biasTensor, + actiDesc, + mode, + runningMean, + runningVar, + saveMean, + saveVar, + output_n, + output_c, + output_h, + output_w, + profiling, + relu); + namespace Kernels { namespace BatchNorm { +BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle, + Allocator allocator, + ffTensorDescriptor_t inputTensor, + ffTensorDescriptor_t outputTensor, + ffTensorDescriptor_t biasTensor, + ffActivationDescriptor_t actiDesc, + ffBatchNormMode_t mode, + float *runningMean, + float *runningVar, + float *saveMean, + float *saveVar, + int output_n, + int output_c, + int output_h, + int output_w, + ProfilingSettings profiling, + bool relu); + void forward_kernel(ffStream_t stream, BatchNormPerDeviceState *m, float const *input_ptr, diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h index d43446883c..28985f5501 100644 --- a/lib/kernels/include/kernels/cast_kernels.h +++ b/lib/kernels/include/kernels/cast_kernels.h @@ -3,19 +3,26 @@ #include "kernels/accessor.h" #include "kernels/device.h" -#include "op-attrs/ffconst.h" namespace FlexFlow { -class CastPerDeviceState : public 
PerDeviceOpState { -public: - CastPerDeviceState(FFHandler handle); - DataType input_data_type, output_data_type; +struct CastPerDeviceState { + PerDeviceFFHandle handle; + DataType input_data_type; + req output_data_type; }; +FF_VISITABLE_STRUCT_NO_EQ(CastPerDeviceState, + handle, + input_data_type, + output_data_type); + namespace Kernels { namespace Cast { +CastPerDeviceState + init_kernel(PerDeviceFFHandle const &, DataType input, DataType output); + void forward_kernel(ffStream_t stream, CastPerDeviceState const *, GenericTensorAccessorR const &input, diff --git a/lib/kernels/include/kernels/combine_kernels.h b/lib/kernels/include/kernels/combine_kernels.h index 174d1eb925..24b9adb803 100644 --- a/lib/kernels/include/kernels/combine_kernels.h +++ b/lib/kernels/include/kernels/combine_kernels.h @@ -6,15 +6,17 @@ namespace FlexFlow { -class CombinePerDeviceState : public PerDeviceOpState { -public: - CombinePerDeviceState(FFHandler handle); - DataType data_type; +struct CombinePerDeviceState { + req data_type; }; +FF_VISITABLE_STRUCT_NO_EQ(CombinePerDeviceState, data_type); + namespace Kernels { namespace Combine { +CombinePerDeviceState init_kernel(DataType data_type); + void forward_kernel(ffStream_t stream, CombinePerDeviceState const *m, GenericTensorAccessorR const &input, diff --git a/lib/kernels/include/kernels/concat_kernels.h b/lib/kernels/include/kernels/concat_kernels.h index 741bbbe9f0..165f63f332 100644 --- a/lib/kernels/include/kernels/concat_kernels.h +++ b/lib/kernels/include/kernels/concat_kernels.h @@ -6,29 +6,29 @@ namespace FlexFlow { -class ConcatPerDeviceState : public PerDeviceOpState { -public: - ConcatPerDeviceState(FFHandler handle) : PerDeviceOpState(handle){}; - int legion_axis; - char op_name[MAX_OPNAME]; +struct ConcatPerDeviceState { + req legion_axis; }; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ConcatPerDeviceState, legion_axis); + namespace Kernels { namespace Concat { -void init_meta(ConcatPerDeviceState *meta, int 
legion_axis); +ConcatPerDeviceState init_kernel(ff_dim_t legion_axis); void forward_kernel(ffStream_t stream, ConcatPerDeviceState const *m, GenericTensorAccessorW const &output, - GenericTensorAccessorR const *inputs, + std::vector const &inputs, int num_inputs); -void backward_kernel(ffStream_t stream, - ConcatPerDeviceState const *m, - GenericTensorAccessorR const &output_grad, - GenericTensorAccessorW const *input_grads, - int num_inputs); +void backward_kernel( + ffStream_t stream, + ConcatPerDeviceState const *m, + GenericTensorAccessorR const &output_grad, + std::vector const &input_grads, + int num_inputs); } // namespace Concat } // namespace Kernels diff --git a/lib/kernels/include/kernels/conv_2d_kernels.h b/lib/kernels/include/kernels/conv_2d_kernels.h index 50b3c0601f..75eefbe1c2 100644 --- a/lib/kernels/include/kernels/conv_2d_kernels.h +++ b/lib/kernels/include/kernels/conv_2d_kernels.h @@ -5,45 +5,50 @@ namespace FlexFlow { -class Conv2DPerDeviceState : public PerDeviceOpState { -public: - Conv2DPerDeviceState(FFHandler handler); - ffTensorDescriptor_t inputTensor, biasTensor, outputTensor; +struct Conv2DPerDeviceState { + PerDeviceFFHandle handle; + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t biasTensor; + ffTensorDescriptor_t outputTensor; ffFilterDescriptor_t filterDesc; ffActivationDescriptor_t actiDesc; ffConvolutionDescriptor_t convDesc; ffConvolutionFwdAlgo_t fwdAlgo; ffConvolutionBwdFilterAlgo_t bwdFilterAlgo; ffConvolutionBwdDataAlgo_t bwdDataAlgo; - bool relu, use_bias; - char op_name[MAX_OPNAME]; + req> activation; + req use_bias; }; +FF_VISITABLE_STRUCT_NO_EQ(Conv2DPerDeviceState, + handle, + inputTensor, + biasTensor, + outputTensor, + filterDesc, + actiDesc, + convDesc, + fwdAlgo, + bwdFilterAlgo, + bwdDataAlgo, + activation, + use_bias); + namespace Kernels { namespace Conv2D { -void init_kernel(Conv2DPerDeviceState *m, - int input_w, - int input_h, - int input_c, - int input_n, - int output_w, - int output_h, - int 
output_c, - int output_n, - int kernel_h, - int kernel_w, - int groups, - int stride_h, - int stride_w, - int pad_h, - int pad_w, - float const *input_ptr, - float *output_ptr, - float const *kernel_ptr, - float *kernel_grad_ptr, - float *forward_time = nullptr, - float *backward_time = nullptr); +Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle, + ffTensorDescriptor_t inputTensor, + ffTensorDescriptor_t biasTensor, + ffTensorDescriptor_t outputTensor, + ffFilterDescriptor_t filterDesc, + ffActivationDescriptor_t actiDesc, + ffConvolutionDescriptor_t convDesc, + ffConvolutionFwdAlgo_t fwdAlgo, + ffConvolutionBwdFilterAlgo_t bwdFilterAlgo, + ffConvolutionBwdDataAlgo_t bwdDataAlgo, + req> relu, + bool use_bias); void forward_kernel(ffStream_t stream, Conv2DPerDeviceState const *m, @@ -58,8 +63,8 @@ void backward_kernel(ffStream_t stream, float *input_grad_ptr, float const *output_ptr, float *output_grad_ptr, - float const *kernel_ptr, - float *kernel_grad_ptr, + float const *filter_ptr, + float *filter_grad_ptr, float *bias_grad_ptr); } // namespace Conv2D diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index 3593ac4ab2..cde0df93c0 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -18,9 +18,6 @@ namespace FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -124,7 +121,7 @@ O = A * B */ void forward_kernel(cudaStream_t stream, - BatchMatmulPerDeviceState const *meta, + BMMPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp b/lib/kernels/src/hip/batch_matmul_kernels.cpp index d8b6500326..a06442d3d6 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -19,9 +19,6 @@ namespace 
FlexFlow { -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -32,7 +29,7 @@ O: (batch, n, m) O = A * B */ void forward_kernel(hipStream_t stream, - BatchMatmulPerDeviceState const *meta, + BMMPerDeviceState const &meta, float *o_ptr, float const *a_ptr, float const *b_ptr, @@ -42,9 +39,9 @@ void forward_kernel(hipStream_t stream, int k, int batch, hipStream_t stream, - int a_seq_length_dim, - int b_seq_length_dim, int seq_length) { + int a_seq_length_dim = meta.a_seq_length_dim; + int b_seq_length_dim = meta.b_seq_length_dim; - checkCUDA(hipblasSetStream(meta->handle.blas, stream)); - checkCUDNN(miopenSetStream(meta->handle.dnn, stream)); + checkCUDA(hipblasSetStream(meta.handle.blas, stream)); + checkCUDNN(miopenSetStream(meta.handle.dnn, stream)); diff --git a/lib/kernels/src/hip/concat_kernels.cpp b/lib/kernels/src/hip/concat_kernels.cpp index e818f8b568..f943bc9156 100644 --- a/lib/kernels/src/hip/concat_kernels.cpp +++ b/lib/kernels/src/hip/concat_kernels.cpp @@ -26,10 +26,6 @@ using Legion::Rect; namespace Kernels { namespace Concat { -void init_meta(ConcatPerDeviceState *m, int legion_axis) { - m->legion_axis = legion_axis; -} - template void calc_blk_size(coord_t &num_blocks, coord_t &blk_size, diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h index 8b451b2705..910d5dc925 100644 --- a/lib/op-attrs/include/op-attrs/get_op_type.h +++ b/lib/op-attrs/include/op-attrs/get_op_type.h @@ -12,6 +12,7 @@ OperatorType get_op_type(BatchMatmulAttrs const &); OperatorType get_op_type(BatchNormAttrs const &); OperatorType get_op_type(BroadcastAttrs const &); OperatorType get_op_type(CastAttrs const &); +OperatorType get_op_type(CombineAttrs const &); OperatorType get_op_type(ConcatAttrs const &); OperatorType get_op_type(Conv2DAttrs const &); OperatorType get_op_type(DropoutAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 
5fd067313e..b64fe73497 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -43,6 +43,7 @@ using SharedOperatorAttrs = variant::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); +static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index c74824570c..00c700ba20 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -12,8 +12,10 @@ struct BatchMatmulAttrs { }; FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); -CHECK_VALID_OP_ATTR(BatchMatmulAttrs); +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs); +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs); +CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index b9bd14a231..cbc864be44 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -9,9 +9,9 @@ namespace FlexFlow { struct ConcatAttrs { - ff_dim_t axis; + req axis; }; -FF_VISITABLE_STRUCT(ConcatAttrs, axis); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ConcatAttrs, axis); CHECK_VALID_OP_ATTR(ConcatAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/src/batch_matmul.cc b/lib/op-attrs/src/batch_matmul.cc index 1cc8c5cfda..bd61c24737 100644 --- a/lib/op-attrs/src/batch_matmul.cc +++ b/lib/op-attrs/src/batch_matmul.cc @@ -2,6 +2,14 @@ namespace FlexFlow { +int get_aSeqLengthDim(BatchMatmulAttrs const &attrs) { + return attrs.a_seq_length_dim; +} + +int get_bSeqLengthDim(BatchMatmulAttrs const &attrs) { + return 
attrs.b_seq_length_dim; +} + /* bool BatchMatmulAttrs::is_valid( */ /* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { */ diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index a7b8d86171..ef7e779469 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -104,13 +104,14 @@ struct FFConfig : public use_visitable_cmp { int python_data_loader_type = 2; }; -class FFIterationConfig { -public: +struct FFIterationConfig { FFIterationConfig(); void reset(); int seq_length; }; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); + enum FieldIDs { FID_DATA, }; diff --git a/lib/runtime/src/ops/attention.cc b/lib/runtime/src/ops/attention.cc index bca87bdb53..94e2b03731 100644 --- a/lib/runtime/src/ops/attention.cc +++ b/lib/runtime/src/ops/attention.cc @@ -121,18 +121,6 @@ static DeviceSpecific int num_samples = get_piece_shape(query_parallel_tensor_shape)[ff_dim_t(2)]; int num_heads = get_piece_shape(weight_parallel_tensor_shape)[ff_dim_t(1)]; - assert(qoSeqLength == query.shape[legion_dim_t(1)]); - assert(qSize == query.shape[legion_dim_t(0)]); - assert(num_samples == key.shape[legion_dim_t(2)]); - assert(kvSeqLength == key.shape[legion_dim_t(1)]); - assert(kSize == key.shape[legion_dim_t(0)]); - assert(num_samples == value.shape[legion_dim_t(2)]); - assert(kvSeqLength == value.shape[legion_dim_t(1)]); - assert(vSize == value.shape[legion_dim_t(0)]); - assert(num_samples == output.shape[legion_dim_t(2)]); - assert(qoSeqLength == output.shape[legion_dim_t(1)]); - assert(oProjSize == output.shape[legion_dim_t(0)]); - DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, @@ -149,9 +137,6 @@ static DeviceSpecific qoSeqLength, kvSeqLength, attrs.add_bias_kv)); - - assert(weight.shape.get_volume() * sizeof(float) == - acc.unwrap(per_device_state)->weightSize); return per_device_state; } diff --git 
a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 3e860bd413..45c5e11b9c 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -15,752 +15,268 @@ #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" -#include "kernels/profiling.h" -#include "legion/legion_utilities.h" -#include "tasks.h" +#include "legion.h" +#include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/batch_matmul.h" namespace FlexFlow { using namespace FlexFlow::Kernels::BatchMatmul; +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; + enum Slots { - A_INPUT, - B_INPUT, - OUTPUT, - A_INPUT_GRAD, - B_INPUT_GRAD, - OUTPUT_GRAD, - ATTRS, - PROFILING + A_INPUT, // tensor + B_INPUT, // tensor + OUTPUT, // tensor + PROFILING, + HANDLE, + A_SEQ_LENGTH_DIM, + B_SEQ_LENGTH_DIM, + PER_DEVICE_STATE, + ITERATION_CONFIG }; OpTaskInvocation init(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; + OpTaskBinding init; - b.bind_arg(ATTRS, attrs); - b.bind_arg(PROFILING, enable_profiling()); + init.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init.bind_arg(HANDLE, ff_handle()); - return {BATCHMATMUL_INIT_TASK_ID, b}; + return {BATCHMATMUL_INIT_TASK_ID, init}; } OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; - - b.bind(A_INPUT, input_tensor(0)); - b.bind(B_INPUT, input_tensor(1)); - b.bind(OUTPUT, output_tensor(0)); - - return {BATCHMATMUL_FWD_TASK_ID, b}; -} - -OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return {BATCHMATMUL_BWD_TASK_ID, b}; -} - -BatchMatmulParams BatchMatmul::get_params() const { - BatchMatmulParams params; - params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; - params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; - return params; 
-} - -Tensor FFModel::batch_matmul(const Tensor A, - const Tensor B, - int a_seq_length_dim, - int b_seq_length_dim, - char const *name) { - Layer *bmm = new Layer(this, - OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B); - assert((a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - int dims[MAX_TENSOR_DIM]; - int numdim = A->num_dims; - for (int i = 0; i < numdim; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - bmm->outputs[0] = create_tensor_legion_ordering( - numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); - bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); - bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); - layers.push_back(bmm); - return bmm->outputs[0]; -} - -Op *BatchMatmul::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("a_seq_length_dim", value); - int a_seq_length_dim = value; - layer->get_int_property("b_seq_length_dim", value); - int b_seq_length_dim = value; - return new BatchMatmul(model, - inputs[0], - inputs[1], - a_seq_length_dim, - b_seq_length_dim, - layer->name); -} - -BatchMatmul::BatchMatmul( - FFModel &model, - BatchMatmulParams const ¶ms, - std::pair const &inputs, - char const *name) - : BatchMatmul(model, - inputs.first, - inputs.second, - params.a_seq_length_dim, - params.b_seq_length_dim, - name) {} - -// return A*B -BatchMatmul::BatchMatmul(FFModel &model, - const ParallelTensor A, - const ParallelTensor B, - int _a_seq_length_dim, - int _b_seq_length_dim, - char const *name) - : Op(model, - 
OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B), - a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), - b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { - assert((_a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((_b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < A->num_dims; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - numOutputs = 1; - outputs[0] = model.create_parallel_tensor_legion_ordering( - A->num_dims, dims, DT_FLOAT, this); - // C is not none - // if (C != Tensor::NO_TENSOR) { - // numInputs = 3; - // assert(C.num_dims == outputs[0].num_dims); - // for (int i = 0; i < C.num_dims; i++) - // assert(C.adim[i] == outputs[0].adim[i]); - //} -} - -void BatchMatmul::serialize(Legion::Serializer &sez) const { - BatchMatmulParams params = get_params(); - sez.serialize(params.a_seq_length_dim); - sez.serialize(params.b_seq_length_dim); -} - -using PCG::Node; -/*static*/ -Node BatchMatmul::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 2); - int a_seq_length_dim, b_seq_length_dim; - dez.deserialize(a_seq_length_dim); - dez.deserialize(b_seq_length_dim); - - BatchMatmulParams params; - params.a_seq_length_dim = a_seq_length_dim; - params.b_seq_length_dim = b_seq_length_dim; - return ff.get_or_create_node({inputs[0], inputs[1]}, params); -} + OpTaskBinding fwd; -Op *BatchMatmul::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - BatchMatmulParams params = get_params(); - return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, 
this->name); -} + fwd.bind(A_INPUT, input_tensor(0)); + fwd.bind(B_INPUT, input_tensor(1)); + fwd.bind(OUTPUT, output_tensor(0)); -template <> -void register_task() { - OpTaskSignature sig(OpTaskType::INIT); + fwd.bind_arg(PROFILING, profiling_settings()); + fwd.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + fwd.bind_arg(ITERATION_CONFIG, iteration_config()); - sig.add_arg_slot(ATTRS); - sig.add_arg_slot(PROFILING); - - register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", sig, init_task); + return {BATCHMATMUL_FWD_TASK_ID, fwd}; } -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(A_INPUT, READ_WRITE); - fwd.add_input_slot(B_INPUT, READ_WRITE); - fwd.add_output_slot(OUTPUT); +OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { + OpTaskBinding bwd = infer_bwd_binding(forward(attrs).binding); - return fwd; + return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + int const a_seq_length_dim = acc.get_argument(A_SEQ_LENGTH_DIM); + int const b_seq_length_dim = acc.get_argument(B_SEQ_LENGTH_DIM); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + Allocator allocator = acc.get_allocator(); - bwd.add_input_slot(A_INPUT); - bwd.add_input_slot(B_INPUT); - bwd.add_input_grad_slot(A_INPUT_GRAD); - bwd.add_input_grad_slot(B_INPUT_GRAD); - bwd.add_output_slot(OUTPUT); - bwd.add_output_grad_slot(OUTPUT_GRAD); + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim)); - return bwd; + return per_device_state; } -OpTaskBinding BatchMatmul::get_init_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); - - return binding; +static DeviceSpecific + init_task(Task const *task, + std::vector 
const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -OpTaskBinding BatchMatmul::get_fwd_task_binding() const { - OpTaskBinding binding; +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto a_input = acc.get_tensor(A_INPUT); + auto b_input = acc.get_tensor(B_INPUT); + auto output = acc.get_tensor(OUTPUT); - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind(OUTPUT, output_tensor(0)); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); - binding.bind_arg(ATTRS, this->attrs); - return binding; -} + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); -OpTaskBinding BatchMatmul::get_bwd_task_binding() const { - OpTaskBinding binding; - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); - binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); - binding.bind(OUTPUT, output_tensor(0)); - binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); + int batch = 1; + for (int i = 2; i < a_input.shape.get_dim(); + i++) { // get_dim() or get_volume()? 
+ int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); + batch *= dim_size; + } - binding.bind_arg(ATTRS, this->attrs); - return binding; + return profile(forward_kernel, + profiling, + "[BatchMatmul] forward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + a_input.get_float_ptr(), + b_input.get_float_ptr(), + m, + n, + k, + batch, + iter_config.seq_length); } -void BatchMatmul::init(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // init_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} -} // namespace FlexFlow -// / -// template -// void BatchMatmul::init_with_dim(FFModel const &ff) { -// assert(check_output_input_weight_same_parallel_is()); -// parallel_is = outputs[0]->parallel_is; -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_init(ff, argmap); -// IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, -// parallel_is, -// TaskArgument(this, sizeof(BatchMatmul)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// FutureMap fm = runtime->execute_index_space(ctx, launcher); -// fm.wait_all_results(); -// set_opmeta_from_futuremap(ff, fm); -// } - -PerDeviceOpState * - 
BatchMatmul::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - - FFHandler handle = *((FFHandler const *)task->local_args); - BatchMatmulPerDeviceState *m = new BatchMatmulPerDeviceState(handle); - m->profiling = profiling; - m->a_seq_length_dim = attrs.a_seq_length_dim; - m->b_seq_length_dim = attrs.b_seq_length_dim; - return m; + forward_task_impl(acc); } -void BatchMatmul::forward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // forward_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} -} +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + // BatchMatmul* bmm = (BatchMatmul*) task->args; + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); -// template -// void BatchMatmul::forward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_forward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_FWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < 
numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](O): output - regions[1](I): A - regions[2](I): B - ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? - output = A * B /////////+ C -*/ -void BatchMatmul::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + // is this equivalent to checking `Domain` equality? + assert(output.shape == output_grad.shape); - TaskArgumentAccessor acc(task, regions, ctx, runtime); + auto a_input = acc.get_tensor(A_INPUT); + auto a_input_grad = acc.get_tensor_grad(A_INPUT); + assert(a_input.shape == a_input_grad.shape); - // const BatchMatmul* bmm = (const BatchMatmul*) task->args; - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - // BatchMatmulMeta const *meta = *((BatchMatmulMeta **)task->local_args); - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - - auto a_input = acc.get_tensor(A_INPUT); - auto b_input = acc.get_tensor(B_INPUT); - auto output = acc.get_tensor(OUTPUT); - - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); - - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); + auto b_input = acc.get_tensor(B_INPUT); + auto b_input_grad = acc.get_tensor_grad(B_INPUT); + assert(b_input.shape == b_input_grad.shape); + + // check dins + int m = b_input.shape[legion_dim_t(0)]; + assert(m == 
output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); + for (int i = 2; i < a_input.shape.dims.num_dims(); i++) { + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); batch *= dim_size; } - float *out_ptr = output.get_float_ptr(); - c float const *a_ptr = a_input.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float const *c_ptr = NULL; - // if (regions.size() == 4) { - // Domain c_domain = runtime->get_index_space_domain( - // ctx, task->regions[3].region.get_index_space()); - // assert(c_domain == a_domain); - // c_ptr = helperGetTensorPointerRO( - // regions[3], task->regions[3], FID_DATA, ctx, runtime); - // } - - profile(forward_kernel, - meta->profiling, - "[BatchMatmul] forward_time = %.2lfms\n", - out_ptr, - a_ptr, - b_ptr, - c_ptr, - m, - n, - k, - batch, - meta->a_seq_length_dim, - meta->b_seq_length_dim, - iter_config->seq_length); -} -void BatchMatmul::backward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - backward_with_dim(ff); \ - break; \ - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); - } + return profile(backward_kernel, + profiling, + "[BatchMatmul] backward_time = %.2lfms\n", + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + a_input.get_float_ptr(), + a_input_grad.get_float_ptr(), + b_input.get_float_ptr(), + b_input_grad.get_float_ptr(), + m, 
+ n, + k, + batch); } -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -// template -// void BatchMatmul::backward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_backward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_BWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// // regions[0](I): output -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// // regions[1](I): output_grad -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region_grad)); -// launcher.add_field(1, FID_DATA); -// // regions[2](I): A -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[0]->region)); -// launcher.add_field(2, FID_DATA); -// // regions[3](I/O): A_grad -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[0]->region_grad)); -// launcher.add_field(3, FID_DATA); -// // regions[4](I): B -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[1]->region)); -// launcher.add_field(4, FID_DATA); -// // regions[5](I/O): B_grad -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[1]->region_grad)); -// launcher.add_field(5, 
FID_DATA); -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -__host__ void - BatchMatmul::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - // Currently assume C is NULL - assert(regions.size() == 6); - assert(task->regions.size() == 6); - // BatchMatmul* bmm = (BatchMatmul*) task->args; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - // output domains - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor(OUTPUT_GRAD); - assert(output == - output_grad); // is this equivalent to checking `Domain` equality? 
- // A domains - auto a_input = acc.get_tensor(A_INPUT); - auto a_input_grad = acc.get_tensor(A_INPUT_GRAD); - assert(a_input == a_input_grad); - // B domains - auto b_input = acc.get_tensor(B_INPUT); - auto b_input_grad = acc.get_tensor(B_INPUT_GRAD); - assert(b_input == b_input_grad); - - // check dins - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); - int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); - batch *= dim_size; - } - // get pointers - float const *out_ptr = output.get_float_ptr(); - float const *out_grad_ptr = output_grad.get_float_ptr(); - float const *a_ptr = a_input.get_float_ptr(); - float *a_grad_ptr = a_input_grad.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float *b_grad_ptr = b_input_grad.get_float_ptr(); - - float *c_grad_ptr = NULL; - - // TODO: add support for meta->a_seq_length_dim >= 0 - // or meta->b_seq_length_dim >= 0 - assert((meta->a_seq_length_dim >= a_len) || (iter_config->seq_length == 0)); - assert((meta->b_seq_length_dim >= b_len) || (iter_config->seq_length == 0)); - - profile(backward_kernel, - meta->profiling, - "[BatchMatmul] backward_time = %.2lfms\n", - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, - m, - n, - k, - batch); + backward_task_impl(acc); } -void BatchMatmul::print_layer(FFModel const &ff) { - return; +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc) { + auto env = sim.new_environment(); + + 
ParallelTensorShape output_shape = + get_output_shape(attrs, a_input.shape, b_input.shape); + + SimTaskBinding init_binding; + init_binding.bind_arg(A_SEQ_LENGTH_DIM, get_aSeqLengthDim(attrs)); + init_binding.bind_arg(B_SEQ_LENGTH_DIM, get_bSeqLengthDim(attrs)); + init_binding.bind_arg(HANDLE, ff_handle()); + + auto init_accessor = + env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(A_INPUT, a_input); + fwd_binding.bind(B_INPUT, b_input); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = + env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -bool BatchMatmul::measure_operator_cost(Simulator *sim, - MachineView const &pc, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input0, sub_input1; - if (!outputs[0]->get_sub_tensor(pc, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(pc, sub_input0)) { - return false; - } - if (!inputs[1]->get_sub_tensor(pc, sub_input1)) { - return false; - } - - int input0_c = sub_input0.dims[0].size; - int input0_r = sub_input0.dims[1].size; - int input1_c = sub_input1.dims[0].size; - int input1_r = sub_input1.dims[1].size; - int output_c = sub_output.dims[0].size; - int output_r = sub_output.dims[1].size; +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); - assert(input0_c == input1_r); - assert(input0_r == 
output_r); - assert(input1_c == output_c); + init.add_arg_slot(A_SEQ_LENGTH_DIM); + init.add_arg_slot(B_SEQ_LENGTH_DIM); + init.add_unchecked_arg_slot(HANDLE); - assert(sub_input0.dims[2] == sub_input1.dims[2]); - assert(sub_input1.dims[2] == sub_output.dims[2]); - int batch = 1; - assert(sub_input0.num_dims == sub_input1.num_dims); - for (int i = 2; i < sub_input0.num_dims; i++) { - assert(sub_input0.dims[i] == sub_input1.dims[i]); - assert(sub_input0.dims[i] == sub_output.dims[i]); - batch *= sub_input0.dims[i].size; - } + register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task); +} - BatchMatmulPerDeviceState *meta = sim->batch_matmul_meta; - - // allocate tensors in simulator - sim->free_all(); - float *a_ptr = (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - assert(a_ptr != NULL); - float *b_ptr = (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - assert(b_ptr != NULL); - float *c_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *out_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - int m = input1_c; - int n = input0_r; - int k = input0_c; - - assert(meta->profiling == false); - - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, meta, out_ptr, a_ptr, b_ptr, c_ptr, m, n, k, batch); - }; - - if (sim->computationMode == COMP_MODE_TRAINING) { - float *a_grad_ptr = - (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - float *b_grad_ptr = - (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - float *c_grad_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *out_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_grad_ptr != NULL); - cost_metrics.outputs_memory += - 
cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, - m, - n, - k, - batch); - }; - } +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf) backward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time); - } + fwd.add_input_slot(A_INPUT); + fwd.add_input_slot(B_INPUT); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - return true; + register_task(BATCHMATMUL_FWD_TASK_ID, "BatchMatmul Fwd", fwd, forward_task); } + +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(ATTENTION_FWD_TASK_ID)); + + register_task(BATCHMATMUL_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); } -; // namespace FlexFlow + +}; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index c133c2a875..018fe1d582 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -1,10 +1,14 @@ #ifndef _FLEXFLOW_BATCH_MATMUL_H #define _FLEXFLOW_BATCH_MATMUL_H +// #include "op-attrs/ops/batch_matmul.h" +// #include "task_spec/op_task_invocation.h" +// #include "task_spec/op_task_signature.h" +// #include "sim_environment.h" 
+ #include "op-attrs/ops/batch_matmul.h" -#include "op_task_invocation.h" -#include "op_task_signature.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -19,68 +23,426 @@ OpTaskInvocation init(BatchMatmulAttrs const &); OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, +CostMetrics measure_operator_cost(SimEnvFactory const &sim, BatchMatmulAttrs const &attrs, - ParallelTensorShape const &lhs_input_shape, - ParallelTensorShape const &rhs_input_shape, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, ProfilingSettings const &settings, - MachineView const &); - -/* class BatchMatmul : public Op { */ -/* public: */ -/* BatchMatmul(FFModel &model, */ -/* const ParallelTensor A, */ -/* const ParallelTensor B, */ -/* int a_seq_length_dim, */ -/* int b_seq_length_dim, */ -/* char const *name = nullptr); */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* /1* static PCG::Node deserialize(FFModel &ff, *1/ */ -/* /1* Legion::Deserializer &d, *1/ */ -/* /1* ParallelTensor inputs[], *1/ */ -/* /1* int num_inputs); *1/ */ -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* 
MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* OpTaskBinding get_init_task_binding() const override; */ -/* OpTaskBinding get_fwd_task_binding() const override; */ -/* OpTaskBinding get_bwd_task_binding() const override; */ -/* private: */ -/* template */ -/* void init_with_dim(FFModel const &ff); */ -/* template */ -/* void forward_with_dim(FFModel const &ff); */ -/* template */ -/* void backward_with_dim(FFModel const &ff); */ - -/* public: */ -/* int a_seq_length_dim, b_seq_length_dim; */ -/* }; */ + MachineView const &pc); } // namespace FlexFlow #endif + +// BatchMatmulParams BatchMatmul::get_params() const { +// BatchMatmulParams params; +// params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; +// params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; +// return params; +// } + +// Tensor FFModel::batch_matmul(const Tensor A, +// const Tensor B, +// int a_seq_length_dim, +// int b_seq_length_dim, +// char const *name) { +// Layer *bmm = new Layer(this, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B); +// assert((a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// int dims[MAX_TENSOR_DIM]; +// int numdim = A->num_dims; +// for (int i = 0; i < numdim; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// bmm->outputs[0] = create_tensor_legion_ordering( +// numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); +// bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); +// bmm->add_int_property("b_seq_length_dim", 
b_seq_length_dim); +// layers.push_back(bmm); +// return bmm->outputs[0]; +// } + +// Op *BatchMatmul::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("a_seq_length_dim", value); +// int a_seq_length_dim = value; +// layer->get_int_property("b_seq_length_dim", value); +// int b_seq_length_dim = value; +// return new BatchMatmul(model, +// inputs[0], +// inputs[1], +// a_seq_length_dim, +// b_seq_length_dim, +// layer->name); +// } + +// BatchMatmul::BatchMatmul( +// FFModel &model, +// BatchMatmulParams const ¶ms, +// std::pair const &inputs, +// char const *name) +// : BatchMatmul(model, +// inputs.first, +// inputs.second, +// params.a_seq_length_dim, +// params.b_seq_length_dim, +// name) {} + +// // return A*B +// BatchMatmul::BatchMatmul(FFModel &model, +// const ParallelTensor A, +// const ParallelTensor B, +// int _a_seq_length_dim, +// int _b_seq_length_dim, +// char const *name) +// : Op(model, +// OP_BATCHMATMUL, +// DT_FLOAT, +// name, +// 2 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// A, +// B), +// a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), +// b_seq_length_dim(B->num_dims - 1 - _b_seq_length_dim) { +// assert((_a_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert((_b_seq_length_dim <= 1) && +// "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " +// "Fortran ordering)."); +// assert(A->num_dims == B->num_dims); +// for (int i = A->num_dims - 1; i >= 2; i--) { +// assert(A->dims[i] == B->dims[i]); +// } +// assert(A->dims[0] == B->dims[1]); +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < A->num_dims; i++) { +// dims[i] = A->dims[i]; +// } +// dims[0] = B->dims[0]; +// numOutputs = 1; +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// A->num_dims, dims, DT_FLOAT, this); +// // C is not none +// // if (C != 
Tensor::NO_TENSOR) { +// // numInputs = 3; +// // assert(C.num_dims == outputs[0].num_dims); +// // for (int i = 0; i < C.num_dims; i++) +// // assert(C.adim[i] == outputs[0].adim[i]); +// //} +// } + +// void BatchMatmul::serialize(Legion::Serializer &sez) const { +// BatchMatmulParams params = get_params(); +// sez.serialize(params.a_seq_length_dim); +// sez.serialize(params.b_seq_length_dim); +// } + +// using PCG::Node; +// /*static*/ +// Node BatchMatmul::deserialize(FFModel &ff, +// Legion::Deserializer &dez, +// ParallelTensor inputs[], +// int num_inputs) { +// assert(num_inputs == 2); +// int a_seq_length_dim, b_seq_length_dim; +// dez.deserialize(a_seq_length_dim); +// dez.deserialize(b_seq_length_dim); + +// BatchMatmulParams params; +// params.a_seq_length_dim = a_seq_length_dim; +// params.b_seq_length_dim = b_seq_length_dim; +// return ff.get_or_create_node({inputs[0], inputs[1]}, params); +// } + +// Op *BatchMatmul::materialize(FFModel &ff, +// ParallelTensor inputs[], +// int num_inputs) const { +// BatchMatmulParams params = get_params(); +// return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); +// } + +// void BatchMatmul::forward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // forward_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, +// get_fwd_task_signature()); break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + +// template +// void BatchMatmul::forward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_FWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); 
+// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// for (int i = 0; i < numInputs; i++) { +// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[i]->region)); +// launcher.add_field(i + 1, FID_DATA); +// } +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](O): output + regions[1](I): A + regions[2](I): B + ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? + output = A * B /////////+ C +*/ + +// void BatchMatmul::init(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #define DIMFUNC(DIM) \ +// case DIM: { \ +// // init_with_dim(ff); +// this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, +// get_init_task_signature()); break; +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } // namespace FlexFlow +// // / +// // template +// // void BatchMatmul::init_with_dim(FFModel const &ff) { +// // assert(check_output_input_weight_same_parallel_is()); +// // parallel_is = outputs[0]->parallel_is; +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_init(ff, argmap); +// // IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(BatchMatmul)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 
/*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // FutureMap fm = runtime->execute_index_space(ctx, launcher); +// // fm.wait_all_results(); +// // set_opmeta_from_futuremap(ff, fm); +// // } + +// OpTaskBinding BatchMatmul::get_bwd_task_binding() const { +// OpTaskBinding binding; +// binding.bind(A_INPUT, input_tensor(0)); +// binding.bind(B_INPUT, input_tensor(1)); +// binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); +// binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); + +// binding.bind(OUTPUT, output_tensor(0)); +// binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + +// static OpTaskSignature get_fwd_task_signature() { +// OpTaskSignature fwd(OpTaskType::FWD); + +// fwd.add_input_slot(A_INPUT, READ_WRITE); +// fwd.add_input_slot(B_INPUT, READ_WRITE); +// fwd.add_output_slot(OUTPUT); + +// return fwd; +// } + +// static OpTaskSignature get_bwd_task_signature() { +// OpTaskSignature bwd(OpTaskType::BWD); + +// bwd.add_input_slot(A_INPUT); +// bwd.add_input_slot(B_INPUT); +// bwd.add_input_grad_slot(A_INPUT_GRAD); +// bwd.add_input_grad_slot(B_INPUT_GRAD); +// bwd.add_output_slot(OUTPUT); +// bwd.add_output_grad_slot(OUTPUT_GRAD); + +// return bwd; +// } + +// OpTaskBinding BatchMatmul::get_init_task_binding() const { +// OpTaskBinding binding; + +// binding.bind_arg(ATTRS, this->attrs); +// binding.bind_arg(PROFILING, this->profiling); + +// return binding; +// } + +// OpTaskBinding BatchMatmul::get_fwd_task_binding() const { +// OpTaskBinding binding; + +// binding.bind(A_INPUT, input_tensor(0)); +// binding.bind(B_INPUT, input_tensor(1)); +// binding.bind(OUTPUT, output_tensor(0)); + +// binding.bind_arg(ATTRS, this->attrs); +// return binding; +// } + +// void BatchMatmul::backward(FFModel const &ff) { +// int dim = outputs[0]->num_dims; +// switch (dim) { +// #d ef ine 
DIMFUNC(DIM) \ +// case DIM: { \ +// backward_with_dim(ff); \ +// break; \ +// } +// LEGION_FOREACH_N(DIMFUNC) +// #undef DIMFUNC +// default: +// assert(false); +// } +// } + +// void BatchMatmul::print_layer(FFModel const &ff) { +// return; +// } + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + regions[5](I/O): B_grad + regions[6](I/O): C_grad +*/ +// template +// void BatchMatmul::backward_with_dim(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher( +// BATCHMATMUL_BWD_TASK_ID, +// parallel_is, +// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// // regions[0](I): output +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// // regions[1](I): output_grad +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// // regions[2](I): A +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(2, FID_DATA); +// // regions[3](I/O): A_grad +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // regions[4](I): B +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[1]->region)); +// launcher.add_field(4, 
FID_DATA); +// // regions[5](I/O): B_grad +// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[1]->region_grad)); +// launcher.add_field(5, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): output + regions[1](I): output_grad + regions[2](I): A + regions[3](I/O): A_grad + regions[4](I): B + regions[5](I/O): B_grad + regions[6](I/O): C_grad +*/ diff --git a/lib/runtime/src/ops/batch_norm.cc b/lib/runtime/src/ops/batch_norm.cc index 98cc4576a1..6ebf359051 100644 --- a/lib/runtime/src/ops/batch_norm.cc +++ b/lib/runtime/src/ops/batch_norm.cc @@ -16,505 +16,285 @@ #include "batch_norm.h" #include "kernels/batch_norm_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" + +namespace FlexFlow { using namespace FlexFlow::Kernels::BatchNorm; -namespace FlexFlow { +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; enum Slots { - INPUT, - SCALE, - BIAS, - OUTPUT, - INPUT_GRAD, - SCALE_GRAD, - BIAS_GRAD, - OUTPUT_GRAD, + INPUT, // tensor + SCALE, // tensor + BIAS, // tensor + OUTPUT, // tensor ATTRS, - PROFILING -} - -Tensor - FFModel::batch_norm(const Tensor input, bool relu, char const *name) { - assert(input->num_dims == 4); /*NCHW*/ - Layer *bm = new Layer(this, - OP_BATCHNORM, - DT_FLOAT, - name, - 1 /*inputs*/, - 2 /*weights*/, - 1 /*outputs*/, - input); - int numdims = 4; - bm->outputs[0] = create_tensor_legion_ordering( - numdims, input->dims, DT_FLOAT, bm, 0, true /*create_grad*/); - bm->add_int_property("relu", relu); - layers.push_back(bm); - return bm->outputs[0]; -} - -/* - locals[0] = scale - locals[1] = bias -*/ -BatchNorm::BatchNorm(FFModel &model, - const ParallelTensor _input, - const ParallelTensor _scale, - const ParallelTensor _bias, - bool _relu, - char const *name) - : Op(model, - OP_BATCHNORM, - DT_FLOAT, - name, - 1 /*inputs*/, - 2 /*weights*/, - 1 
/*outputs*/, - _input, - _scale, - _bias), - relu(_relu) { - assert(_input->num_dims == 4); - numOutputs = 1; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < _input->num_dims; i++) { - dims[i] = _input->dims[_input->num_dims - 1 - i]; - } - outputs[0] = - model.create_parallel_tensor(_input->num_dims, dims, DT_FLOAT, this); - return; -} - -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - - // init.add_input_slot(INPUT); - // init.add_param_slot(SCALE); - // init.add_param_slot(BIAS); - init.add_output_slot(OUTPUT); -} - -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_param_slot(SCALE); - fwd.add_param_slot(BIAS); - fwd.add_output_slot(OUTPUT, WRITE_DISCARD); - - return fwd; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_arg_slot(ATTRS); + PROFILING, + PER_DEVICE_STATE, + RELU, + HANDLE +}; - bwd.add_input_slot(INPUT); - bwd.add_input_grad_slot(INPUT_GRAD, READ_WRITE); - bwd.add_param_slot(SCALE); - bwd.add_param_grad_slot(SCALE_GRAD, READ_WRITE); - bwd.add_param_grad_slot(BIAS_GRAD, READ_WRITE); - bwd.add_output_grad_slot(OUTPUT_GRAD); - - return bwd; -} - -OpTaskBinding BatchNorm::get_init_task_binding() const { +OpTaskInvocation init(BatchNormAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); - - // binding.bind(INPUT, input_tensor(0)); - // binding.bind(SCALE, param_tensor(0)); - // binding.bind(BIAS, param_tensor(1)); + binding.bind(INPUT, input_tensor(0)); + binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + binding.bind_arg(ATTRS, attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(HANDLE, ff_handle()); + + return 
{BATCHNORM_INIT_TASK_ID, binding}; } -OpTaskBinding BatchNorm::get_fwd_task_binding() const { +OpTaskInvocation forward(BatchNormAttrs const &attrs) { OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); binding.bind(INPUT, input_tensor(0)); - binding.bind(SCALE, param_tensor(0)); - binding.bind(BIAS, param_tensor(1)); + binding.bind(SCALE, input_tensor(1)); + binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {BATCHNORM_FWD_TASK_ID, binding}; } -OpTaskBinding BatchNorm::get_bwd_task_binding() const { - OpTaskBinding binding; +OpTaskInvocation backward(BatchNormAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - binding.bind_arg(ATTRS, this->attrs); + return {BATCHNORM_BWD_TASK_ID, binding}; +} - binding.bind(INPUT, input_tensor(0)); - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(SCALE, param_tensor(0)); - binding.bind(SCALE_GRAD, param_tensor(0).grad()); - binding.bind(BIAS_GRAD, param_tensor(1).grad()); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + Allocator allocator = acc.get_allocator(); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto output = acc.get_tensor(OUTPUT); + auto const &attrs = acc.get_argument(ATTRS); + + int output_w = output.shape[legion_dim_t(0)]; + int output_h = output.shape[legion_dim_t(1)]; + int output_c = output.shape[legion_dim_t(2)]; + int output_n = output.shape[legion_dim_t(3)]; + + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffTensorDescriptor_t biasTensor; + ffActivationDescriptor_t actiDesc; + ffBatchNormMode_t mode; + + size_t totalSize = sizeof(float) * output_c * 4; + float *runningMean = (float 
*)allocator.allocate(totalSize); + float *runningVar = (float *)runningMean + output_c; + float *saveMean = (float *)runningVar + output_c; + float *saveVar = (float *)saveMean + output_c; + + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, + allocator, + inputTensor, + outputTensor, + biasTensor, + actiDesc, + mode, + runningMean, + runningVar, + saveMean, + saveVar, + output_n, + output_c, + output_h, + output_w, + profiling, + attrs.relu)); + + return per_device_state; +} - return binding; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -void BatchNorm::init(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(BATCHNORM_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(BatchNorm)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // 
launcher.add_region_requirement(RegionRequirement(weights[1]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[1]->region)); - // launcher.add_field(3, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto scale = acc.get_tensor(SCALE); + auto bias = acc.get_tensor(SCALE); + + return profile(forward_kernel, + profiling, + "[BatchNorm] forward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output.get_float_ptr(), + scale.get_float_ptr(), + bias.get_float_ptr()); } -/* - regions[0]: input - regions[1]: output - regions[2](I): scale - regions[3](I): bias -*/ -PerDeviceOpState * - BatchNorm::init_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); TaskArgumentAccessor acc(task, regions, ctx, runtime); - FFHandler handle = *((FFHandler const *)task->local_args); - - auto output = acc.get_tensor(OUTPUT); - - int output_w = output.shape[0]; - int output_h = output.shape[1]; - int output_c = output.shape[2]; - int output_n = output.shape[3]; - - Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) - .only_kind(Memory::GPU_FB_MEM) - .best_affinity_to(task->target_proc) - .first(); - BatchNormPerDeviceState *m = new BatchNormPerDeviceState( - handle, bm, gpu_mem, output_n, output_c, output_h, output_w); - return m; + forward_task_impl(acc); } -void BatchNorm::forward(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = 
ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(BATCHNORM_FWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, 0), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_DISCARD, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[1]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[1]->region)); - // launcher.add_field(3, FID_DATA); - - // runtime->execute_index_space(ctx, launcher); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + auto scale = acc.get_tensor(SCALE); + auto scale_grad = acc.get_tensor_grad(SCALE); + auto bias_grad = acc.get_tensor_grad(BIAS); + + return profile(backward_kernel, + profiling, + "[BatchNorm] backward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output_grad.get_float_ptr(), + output.get_float_ptr(), + input_grad.get_float_ptr(), + scale.get_float_ptr(), + scale_grad.get_float_ptr(), + bias_grad.get_float_ptr(), + 
output.shape.get_volume()); } -/* - regions[0](I): input - regions[1](O): ouptut - regions[2](I): scale - regions[3](I): bias -*/ -void BatchNorm::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - // const BatchNorm* bm = (BatchNorm*) task->args; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - BatchNormPerDeviceState *m = *((BatchNormPerDeviceState **)task->local_args); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto scale = acc.get_tensor(SCALE); - auto bias = acc.get_tensor(SCALE); - - profile(forward_kernel, - m->profiling, - "[BatchNorm] forward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output.get_float_ptr(), - scale.get_float_ptr(), - bias.get_float_ptr()); + backward_task_impl(acc); } -void BatchNorm::backward(FFModel const &ff) { - this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(BATCHNORM_BWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, 0), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // // regions[0](I): input - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // // regions[1](I/O): input_grad (we only need grad tensors) - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // inputs[0]->region_grad)); - // launcher.add_field(1, FID_DATA); - // // regions[2](I): output - 
// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(2, FID_DATA); - // // regions[3](I/O): output_grad - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(3, FID_DATA); - // // regions[4](I): filter - // launcher.add_region_requirement(RegionRequirement(weights[0]->region, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(4, FID_DATA); - // // regions[5](I/O): filter_grad - // launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // weights[0]->region_grad)); - // launcher.add_field(5, FID_DATA); - // // regions[6](I/O): bias_grad - // launcher.add_region_requirement(RegionRequirement(weights[1]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // weights[1]->region_grad)); - // launcher.add_field(6, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchNormAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + InputParallelTensorDesc const &scale_shape, + InputParallelTensorDesc const &bias_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + + // int output_w = sub_output.dims[0].size; + // int output_h = sub_output.dims[1].size; + // int output_c = sub_output.dims[2].size; + // int output_n = sub_output.dims[3].size; + // BatchNormPerDeviceState *m = new BatchNormPerDeviceState( + // sim->handler, this, sim->memory, output_n, output_c, output_h, + // output_w); + + // sim->free_all(); + // float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), + // DT_FLOAT); assert(input_ptr != NULL); 
cost_metrics.inputs_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + // float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), + // DT_FLOAT); assert(output_ptr != NULL); cost_metrics.outputs_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + // float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); + // assert(bias_ptr != NULL); + // float *scale_ptr = (float *)sim->allocate(output_c, DT_FLOAT); + // assert(scale_ptr != NULL); + // cost_metrics.weights_memory += + // cost_metrics.total_mem_diff_from(sim->offset); + + auto env = sim.new_environment(); + + ParallelTensorShape output_shape = get_output_shape(attrs); + + SimTaskBinding init_binding; + init_binding.bind(INPUT, input_shape); + init_binding.bind(BIAS, bias_shape); + init_binding.bind(OUTPUT, output_shape); + + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(PROFILING, settings); + init_binding.bind_arg(HANDLE, ff_handle()); + + auto init_accessor = + env.get_init_accessor(ATTENTION_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(SCALE, scale_shape); + fwd_binding.bind(BIAS, bias_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = env.get_fwd_accessor(ATTENTION_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(ATTENTION_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -/* - regions[0](I): input - regions[1](I/O): input_grad - regions[2](I): output - 
regions[3](I/O): output_grad - regions[4](I): scale - regions[5](I/O): scale_grad - regions[6](I/O): bias_grad -*/ -__host__ void - BatchNorm::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 7); - assert(task->regions.size() == 7); - // float beta = 0.0f; - // const BatchNorm* bm = (BatchNorm*) task->args; - TaskArgumentAccessor acc(task, regions, ctx, runtime); - BatchNormPerDeviceState *m = *((BatchNormPerDeviceState **)task->local_args); - - auto input = acc.get_tensor(INPUT); - auto input_grad = acc.get_tensor_grad(INPUT_GRAD); - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT_GRAD); - auto scale = acc.get_tensor(SCALE); - auto scale_grad = acc.get_tensor_grad(SCALE_GRAD); - auto bias_grad = acc.get_tensor_grad(BIAS_GRAD); - - profile(backward_kernel, - m->profiling, - "[BatchNorm] backward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output_grad.get_float_ptr(), - output.get_float_ptr(), - input_grad.get_float_ptr(), - scale.get_float_ptr(), - scale_grad.get_float_ptr(), - bias_grad.get_float_ptr(), - output.get_volume()); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + init.add_input_slot(INPUT); + init.add_input_slot(BIAS); + init.add_output_slot(OUTPUT); + init.add_arg_slot(ATTRS); + init.add_arg_slot(PROFILING); + init.add_unchecked_arg_slot(HANDLE); + + register_task(BATCHNORM_INIT_TASK_ID, "BatchNorm Init", init, init_task); +} + +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_input_slot(INPUT); + fwd.add_input_slot(SCALE); + fwd.add_input_slot(BIAS); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + + register_task(BATCHNORM_FWD_TASK_ID, "BatchNorm Fwd", fwd, forward_task); } -bool BatchNorm::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - 
ParallelTensorBase sub_input, sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - - int output_w = sub_output.dims[0].size; - int output_h = sub_output.dims[1].size; - int output_c = sub_output.dims[2].size; - int output_n = sub_output.dims[3].size; - BatchNormPerDeviceState *m = new BatchNormPerDeviceState( - sim->handler, this, sim->memory, output_n, output_c, output_h, output_w); - - sim->free_all(); - float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_ptr != NULL); - float *scale_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(scale_ptr != NULL); - cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset); - - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, m, input_ptr, output_ptr, scale_ptr, bias_ptr); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - float *input_grad_ptr = - (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_grad_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - float *scale_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(scale_grad_ptr != NULL); - float *bias_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_grad_ptr != NULL); - 
cost_metrics.weights_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - m, - input_ptr, - output_grad_ptr, - output_ptr, - input_grad_ptr, - scale_ptr, - scale_grad_ptr, - bias_grad_ptr, - sub_output.get_volume()); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchNorm] name(%s) size(%zu) forward_time(%.4lf) " - "backward_time(%.4lf)\n", - name, - sub_input.get_volume(), - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchNorm] name(%s) size(%zu) forward_time(%.4lf)\n", - name, - sub_input.get_volume(), - cost_metrics.forward_time); - } - // Free batchnormmeta - delete m; - return true; +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(BATCHNORM_FWD_TASK_ID)); + + register_task(BATCHNORM_BWD_TASK_ID, "BatchNorm Bwd", bwd, backward_task); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_norm.h b/lib/runtime/src/ops/batch_norm.h index e54331665e..94bda5122b 100644 --- a/lib/runtime/src/ops/batch_norm.h +++ b/lib/runtime/src/ops/batch_norm.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_BATCH_NORM_H #include "op-attrs/ops/batch_norm.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -66,3 +66,243 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + +// void BatchNorm::init(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); +// } + +// void BatchNorm::forward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); +// } + +// void BatchNorm::backward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); +// } + +// Tensor 
batch_norm(const Tensor input, bool relu, char const *name) { +// assert(input->num_dims == 4); /*NCHW*/ +// Layer *bm = new Layer(this, +// OP_BATCHNORM, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// 2 /*weights*/, +// 1 /*outputs*/, +// input); +// int numdims = 4; +// bm->outputs[0] = create_tensor_legion_ordering( +// numdims, input->dims, DT_FLOAT, bm, 0, true /*create_grad*/); +// bm->add_int_property("relu", relu); +// layers.push_back(bm); +// return bm->outputs[0]; +// } + +// BatchNorm::BatchNorm(FFModel &model, +// const ParallelTensor _input, +// const ParallelTensor _scale, +// const ParallelTensor _bias, +// bool _relu, +// char const *name) +// : Op(model, +// OP_BATCHNORM, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// 2 /*weights*/, +// 1 /*outputs*/, +// _input, +// _scale, +// _bias), +// relu(_relu) { +// assert(_input->num_dims == 4); +// numOutputs = 1; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < _input->num_dims; i++) { +// dims[i] = _input->dims[_input->num_dims - 1 - i]; +// } +// outputs[0] = +// model.create_parallel_tensor(_input->num_dims, dims, DT_FLOAT, this); +// return; +// } + +/* + locals[0] = scale + locals[1] = bias +*/ + +// void BatchNorm::init(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_INIT_TASK_ID, get_init_task_signature()); +// assert(check_output_input_weight_same_parallel_is()); +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_init(ff, argmap); +// IndexLauncher launcher(BATCHNORM_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(BatchNorm)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// 
launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[1]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[1]->region)); +// launcher.add_field(3, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// set_opmeta_from_futuremap(ff, fm); +// } + +/* + regions[0]: input + regions[1]: output + regions[2](I): scale + regions[3](I): bias +*/ + +// void BatchNorm::forward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_FWD_TASK_ID, get_fwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher(BATCHNORM_FWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, 0), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_DISCARD, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[1]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// 
weights[1]->region)); +// launcher.add_field(3, FID_DATA); + +// runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): input + regions[1](O): ouptut + regions[2](I): scale + regions[3](I): bias +*/ + +// void BatchNorm::backward(FFModel const &ff) { +// this->execute_task(ff, BATCHNORM_BWD_TASK_ID, get_bwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher(BATCHNORM_BWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, 0), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// // regions[0](I): input +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// // regions[1](I/O): input_grad (we only need grad tensors) +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// // regions[2](I): output +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(2, FID_DATA); +// // regions[3](I/O): output_grad +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // regions[4](I): filter +// launcher.add_region_requirement(RegionRequirement(weights[0]->region, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(4, FID_DATA); +// // regions[5](I/O): filter_grad +// launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, +// 0 /*projection id*/, 
+// READ_WRITE, +// EXCLUSIVE, +// weights[0]->region_grad)); +// launcher.add_field(5, FID_DATA); +// // regions[6](I/O): bias_grad +// launcher.add_region_requirement(RegionRequirement(weights[1]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// weights[1]->region_grad)); +// launcher.add_field(6, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): input + regions[1](I/O): input_grad + regions[2](I): output + regions[3](I/O): output_grad + regions[4](I): scale + regions[5](I/O): scale_grad + regions[6](I/O): bias_grad +*/ diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index 23c1bc9940..36afdefcef 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -16,441 +16,178 @@ #include "cast.h" #include "kernels/cast_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" #include "utils/hash-utils.h" using namespace FlexFlow::Kernels::Cast; -namespace FlexFlow { - -enum Slots { - INPUT, - OUTPUT, - INPUT_GRAD, - OUTPUT_GRAD, - ATTRS, - PROFILING -} - -// declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; - -Tensor FFModel::cast(const Tensor input, DataType dtype, char const *name) { - Layer *cast = new Layer(this, - OP_CAST, - dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input); - int numdims = input->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = input->dims[i]; - } - cast->outputs[0] = create_tensor_legion_ordering( - numdims, dims, dtype, cast, 0, true /*create_grad*/); - cast->add_int_property("dtype", dtype); - layers.push_back(cast); - 
return cast->outputs[0]; -} - -Op *Cast::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("dtype", value); - DataType dtype = (DataType)value; - return new Cast(model, inputs[0], dtype, layer->name); -} - -CastParams Cast::get_params() const { - CastParams params; - params.dtype = this->outputs[0]->data_type; - return params; -} - -Cast::Cast(FFModel &model, - ParallelTensor const &input, - DataType _dtype, - char const *name) - : Op(model, - OP_CAST, - _dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input) { - numOutputs = 1; - numWeights = 0; - int numdim = input->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = input->dims[i]; - } - outputs[0] = - model.create_parallel_tensor_legion_ordering(numdim, dims, _dtype, this); -} - -Cast::Cast(FFModel &model, - CastParams const ¶ms, - ParallelTensor const &input, - char const *name) - : Cast(model, input, params.dtype, name) {} - -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - - init.add_input_slot(INPUT); - init.add_output_slot(OUTPUT); - - return init; -} - -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - - return init; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_arg_slot(ATTRS); - bwd.add_input_grad_slot(INPUT_GRAD); - bwd.add_output_grad_slot(OUTPUT_GRAD); +namespace FlexFlow { - return bwd; -} +enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, PER_DEVICE_STATE, HANDLE }; -OpTaskBinding Cast::get_init_task_binding() const { +// declare Legion names +// using Legion::ArgumentMap; +// using Legion::Context; +// using Legion::coord_t; +// using Legion::Domain; 
+// using Legion::FutureMap; +// using Legion::IndexLauncher; +// using Legion::PhysicalRegion; +// using Legion::Predicate; +// using Legion::Rect; +// using Legion::RegionRequirement; +// using Legion::Runtime; +// using Legion::Task; +// using Legion::TaskArgument; +// using Legion::TaskLauncher; + +OpTaskInvocation init(CastAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(PROFILING, this->profiling); - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(HANDLE, ff_handle()); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {CAST_INIT_TASK_ID, binding}; } -OpTaskBinding Cast::get_fwd_task_binding() const { +OpTaskInvocation forward(CastAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind_arg(PROFILING, profiling_settings()); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {CAST_FWD_TASK_ID, binding}; } -OpTaskBinding Cast::get_bwd_task_binding() const { - OpTaskBinding binding; +OpTaskInvocation backward(CastAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - binding.bind_arg(ATTRS, this->attrs); + return {CAST_BWD_TASK_ID, binding}; +} - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { - return binding; -} + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); -void Cast::init(FFModel const &ff) { - this->execute_task(ff, CAST_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // 
set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(CAST_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Cast)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, input.data_type, output.data_type)); + return per_device_state; } -PerDeviceOpState *Cast::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - - FFHandler handler = *((FFHandler const *)task->local_args); - CastPerDeviceState *m = new CastPerDeviceState(handler); - bool profiling = acc.get_argument(PROFILING); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - m->input_data_type = input->data_type; - m->output_data_type = output->data_type; - m->profiling = profiling; - return m; + return init_task_impl(acc); } -void Cast::forward(FFModel const &ff) { - this->execute_task(ff, CAST_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(CAST_FWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, false), - // 
argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // runtime->execute_index_space(ctx, launcher); -} +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); -// template -// void Cast::forward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->output_data_type == DT_FLOAT) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_DOUBLE) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_INT32) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->output_data_type == DT_INT64) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::forward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT 
*input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerWO( -// regions[1], task->regions[1], FID_DATA, ctx, runtime); -// forward_kernel_wrapper( -// m, input_ptr, output_ptr, output_domain.get_volume()); -// } - -void Cast::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); - // if (m->input_data_type == DT_FLOAT) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_DOUBLE) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_INT32) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_INT64) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - profile(forward_kernel, - m->profiling, - "[Cast] forward_time = %.2lfms\n", - m, - input, - output) -} + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); -void Cast::backward(FFModel const &ff) { - this->execute_task(ff, CAST_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(CAST_BWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, false), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - // 0 /*projection 
id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // inputs[0]->region_grad)); - // launcher.add_field(1, FID_DATA); - // runtime->execute_index_space(ctx, launcher); + return profile(forward_kernel, + profiling, + "[Cast] forward_time = %.2lfms\n", + &per_device_state, + input, + output); } -// template -// void Cast::backward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->input_data_type == DT_FLOAT) { -// Cast::backward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->input_data_type == DT_DOUBLE) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT32) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT64) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::backward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT *input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerRW( -// regions[1], task->regions[1], FID_DATA, ctx, runtime); -// backward_kernel_wrapper( -// input_ptr, output_ptr, output_domain.get_volume()); -// } - -void Cast::backward_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); - // if 
(m->output_data_type == DT_FLOAT) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_DOUBLE) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_INT32) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_INT64) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } - auto input_grad = acc.get_tensor(INPUT); - auto output_grad = acc.get_tensor(OUTPUT); - - profile(backward_kernel, - m->profiling, - "[Cast] forward_time = %.2lfms\n", - m, - input_grad, - output_grad) + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); +} + +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + + return profile(backward_kernel, + profiling, + "[Cast] forward_time = %.2lfms\n", + &per_device_state, + input_grad, + output_grad); +} + +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -bool Cast::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + CastAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + // Assume cast has no cost - cost_metrics.forward_time = 0.0f; - cost_metrics.backward_time = 0.0f; - cost_metrics.inputs_memory = 0; - cost_metrics.outputs_memory = 0; - cost_metrics.weights_memory = 0; - return true; + float forward_time = 0.0; + float 
backward_time = 0.0; + float sync_time = 0.0; + return make_metrics(forward_time, backward_time, sync_time, env); } -void Cast::serialize(Legion::Serializer &sez) const { - sez.serialize(this->outputs[0]->data_type); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + + init.add_unchecked_arg_slot(HANDLE); + + init.add_input_slot(INPUT); + init.add_output_slot(OUTPUT); + + register_task(CAST_INIT_TASK_ID, "Cast Init", init, init_task); } -using PCG::Node; +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); -Node Cast::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 1); - DataType dtype; - dez.deserialize(dtype); - return ff.get_or_create_node(inputs[0], {dtype}); + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + + register_task(CAST_FWD_TASK_ID, "Cast Fwd", fwd, forward_task); } -Op *Cast::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - assert(num_inputs == 1); - return new Cast(ff, inputs[0], this->outputs[0]->data_type, this->name); +template <> +void register_task() { + OpTaskSignature bwd = infer_bwd_signature(get_op_signature(CAST_FWD_TASK_ID)); + + register_task(CAST_BWD_TASK_ID, "Cast Bwd", bwd, backward_task); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/cast.h b/lib/runtime/src/ops/cast.h index 7d346584d5..c3781ad783 100644 --- a/lib/runtime/src/ops/cast.h +++ b/lib/runtime/src/ops/cast.h @@ -16,8 +16,8 @@ #define _FLEXFLOW_CAST_H #include "op-attrs/ops/cast.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -33,11 +33,54 @@ OpTaskInvocation forward(CastAttrs const &); OpTaskInvocation backward(CastAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - BatchNormAttrs const &attrs, 
+ CastAttrs const &attrs, ParallelTensorShape const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); +} // namespace FlexFlow + +#endif + +// template +// void Cast::backward_task_with_1_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// if (m->input_data_type == DT_FLOAT) { +// Cast::backward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->input_data_type == DT_DOUBLE) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->input_data_type == DT_INT32) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->input_data_type == DT_INT64) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } +// } + +// template +// void Cast::backward_task_with_2_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// assert(regions.size() == 2); +// assert(task->regions.size() == regions.size()); +// // Domain input_domain = runtime->get_index_space_domain( +// // ctx, task->regions[0].region.get_index_space()); +// Domain output_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// const IDT *input_ptr = helperGetTensorPointerRO( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// ODT *output_ptr = helperGetTensorPointerRW( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); +// backward_kernel_wrapper( +// input_ptr, output_ptr, output_domain.get_volume()); +// } + /* class Cast : public Op { */ /* public: */ /* Cast(FFModel &model, */ @@ -80,6 +123,222 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, /* CostMetrics &cost_metrics) const; */ /* }; */ -} // namespace FlexFlow +// void Cast::backward(FFModel const &ff) { +// this->execute_task(ff, CAST_BWD_TASK_ID, get_bwd_task_signature()); +// 
ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher(CAST_BWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, false), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } -#endif +// template +// void Cast::forward_task_with_1_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// if (m->output_data_type == DT_FLOAT) { +// Cast::forward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->output_data_type == DT_DOUBLE) { +// Cast::forward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->output_data_type == DT_INT32) { +// Cast::forward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->output_data_type == DT_INT64) { +// Cast::forward_task_with_2_type(task, regions, ctx, +// runtime); +// } +// } + +// template +// void Cast::forward_task_with_2_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// assert(regions.size() == 2); +// assert(task->regions.size() == regions.size()); +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// // Domain input_domain = runtime->get_index_space_domain( +// // ctx, task->regions[0].region.get_index_space()); +// Domain output_domain = runtime->get_index_space_domain( +// 
ctx, task->regions[1].region.get_index_space()); +// const IDT *input_ptr = helperGetTensorPointerRO( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// ODT *output_ptr = helperGetTensorPointerWO( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); +// forward_kernel_wrapper( +// m, input_ptr, output_ptr, output_domain.get_volume()); +// } + +// void Cast::forward(FFModel const &ff) { +// this->execute_task(ff, CAST_FWD_TASK_ID, get_fwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher(CAST_FWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, false), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +// void Cast::init(FFModel const &ff) { +// this->execute_task(ff, CAST_INIT_TASK_ID, get_init_task_signature()); +// assert(check_output_input_weight_same_parallel_is()); +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_init(ff, argmap); +// IndexLauncher launcher(CAST_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(Cast)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// 
outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// set_opmeta_from_futuremap(ff, fm); +// } + +// void Cast::serialize(Legion::Serializer &sez) const { +// sez.serialize(this->outputs[0]->data_type); +// } + +// using PCG::Node; + +// Node Cast::deserialize(FFModel &ff, +// Legion::Deserializer &dez, +// ParallelTensor inputs[], +// int num_inputs) { +// assert(num_inputs == 1); +// DataType dtype; +// dez.deserialize(dtype); +// return ff.get_or_create_node(inputs[0], {dtype}); +// } + +// Op *Cast::materialize(FFModel &ff, +// ParallelTensor inputs[], +// int num_inputs) const { +// assert(num_inputs == 1); +// return new Cast(ff, inputs[0], this->outputs[0]->data_type, this->name); +// } + +// Cast::Cast(FFModel &model, +// ParallelTensor const &input, +// DataType _dtype, +// char const *name) +// : Op(model, +// OP_CAST, +// _dtype, +// name, +// 1 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// input) { +// numOutputs = 1; +// numWeights = 0; +// int numdim = input->num_dims; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdim; i++) { +// dims[i] = input->dims[i]; +// } +// outputs[0] = +// model.create_parallel_tensor_legion_ordering(numdim, dims, _dtype, +// this); +// } + +// Tensor FFModel::cast(const Tensor input, DataType dtype, char const *name) { +// Layer *cast = new Layer(this, +// OP_CAST, +// dtype, +// name, +// 1 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// input); +// int numdims = input->num_dims; +// int dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdims; i++) { +// dims[i] = input->dims[i]; +// } +// cast->outputs[0] = create_tensor_legion_ordering( +// numdims, dims, dtype, cast, 0, true /*create_grad*/); +// 
cast->add_int_property("dtype", dtype); +// layers.push_back(cast); +// return cast->outputs[0]; +// } + +// Op *Cast::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("dtype", value); +// DataType dtype = (DataType)value; +// return new Cast(model, inputs[0], dtype, layer->name); +// } + +// CastParams Cast::get_params() const { +// CastParams params; +// params.dtype = this->outputs[0]->data_type; +// return params; +// } + +// Cast::Cast(FFModel &model, +// CastParams const ¶ms, +// ParallelTensor const &input, +// char const *name) +// : Cast(model, input, params.dtype, name) {} diff --git a/lib/runtime/src/ops/combine.cc b/lib/runtime/src/ops/combine.cc index 2485955124..a22b317d10 100644 --- a/lib/runtime/src/ops/combine.cc +++ b/lib/runtime/src/ops/combine.cc @@ -13,322 +13,158 @@ * limitations under the License. */ -#include "parallel_ops/combine.h" +#include "combine.h" #include "kernels/combine_kernels.h" #include "utils/hash-utils.h" namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::LogicalPartition; -using Legion::LogicalRegion; -using Legion::Machine; -using Legion::Memory; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Combine; -CombineParams Combine::get_params() const { - CombineParams params; - params.combine_legion_dim = this->combine_dim; - params.combine_degree = this->combine_degree; - return params; -} +enum Slots { INPUT, OUTPUT, PROFILING, PER_DEVICE_STATE }; -Combine::Combine(FFModel &model, - CombineParams const ¶ms, - ParallelTensor const input, - char const *name) - : 
Combine(model, - input, - params.combine_legion_dim, - params.combine_degree, - name) {} - -Combine::Combine(FFModel &model, - const ParallelTensor _input, - int _combine_legion_dim, - int _combine_degree, - char const *name) - : ParallelOp(model, OP_COMBINE, name, _input), - combine_dim(_combine_legion_dim), combine_degree(_combine_degree) { - int numdim = _input->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = _input->dims[i]; - } - assert(combine_degree > 0 && "Must use combine_degree > 0"); - assert(dims[combine_dim].degree % combine_degree == 0); - dims[combine_dim].degree /= combine_degree; - ParallelTensorBase::update_parallel_ids(numdim, dims); - outputs[0] = model.create_parallel_tensor_legion_ordering( - numdim, dims, DT_FLOAT, this); - // inputs[0]->print("Combine::input"); - // outputs[0]->print("Combine::output"); -} +OpTaskInvocation init(CombineAttrs const &attrs) { + OpTaskBinding binding; + + binding.bind(INPUT, input_tensor(0)); -PerDeviceOpState *Combine::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - Combine *rep = (Combine *)task->args; - // FFHandler handle = *((FFHandler *)task->local_args); - // CombineMeta* m = new CombineMeta(handle); - // m->data_type = rep->outputs[0]->data_type; - return nullptr; + return {COMBINE_INIT_TASK_ID, binding}; } -void Combine::init(FFModel const &ff) { - parallel_is = outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - assert(numOutputs == 1); - assert(numInputs == 1); - IndexLauncher launcher(COMBINE_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(Combine)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement( - input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); - launcher.add_field(0, FID_DATA); - 
launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(1, FID_DATA); - FutureMap fm = runtime->execute_index_space(ctx, launcher); - fm.wait_all_results(); +OpTaskInvocation forward(CombineAttrs const &attrs) { + OpTaskBinding binding; + + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); + binding.bind_arg(PROFILING, profiling_settings()); + + binding.bind(INPUT, input_tensor(0)); + binding.bind(OUTPUT, output_tensor(0)); + + return {COMBINE_FWD_TASK_ID, binding}; } -void Combine::create_input_partition(FFModel &ff) { - assert(outputs[0]->part != LogicalPartition::NO_PART); - assert(inputs[0]->part != LogicalPartition::NO_PART); - ff.create_disjoint_partition(outputs[0]->num_dims, - outputs[0]->dims, - outputs[0]->parallel_is, - inputs[0]->region, - input_lp); - ff.create_disjoint_partition(inputs[0]->num_dims, - inputs[0]->dims, - inputs[0]->parallel_is, - outputs[0]->region_grad, - output_grad_lp); +OpTaskInvocation backward(CombineAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); + + return {COMBINE_BWD_TASK_ID, b}; } -void Combine::forward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - assert(numOutputs == 1); - assert(numInputs == 1); - assert(inputs[0]->data_type == outputs[0]->data_type); - DataType data_type = inputs[0]->data_type; - IndexLauncher launcher(COMBINE_FWD_TASK_ID, - outputs[0]->parallel_is, - TaskArgument(&data_type, sizeof(data_type)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement( - input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - 
outputs[0]->region)); - launcher.add_field(1, FID_DATA); - runtime->execute_index_space(ctx, launcher); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + + auto input = acc.get_tensor(INPUT); + + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(input.data_type)); + return per_device_state; } -void Combine::backward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - assert(numOutputs == 1); - assert(numInputs == 1); - assert(inputs[0]->data_type == outputs[0]->data_type); - DataType data_type = inputs[0]->data_type; - IndexLauncher launcher(COMBINE_BWD_TASK_ID, - inputs[0]->parallel_is, - TaskArgument(&data_type, sizeof(DataType)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - inputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement(output_grad_lp, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - inputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - runtime->execute_index_space(ctx, launcher); +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -bool Combine::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - // TODO: to be implemented - cost_metrics = CostMetrics(); - cost_metrics.forward_time = 0.05f; - cost_metrics.backward_time = 0.05f; - return true; +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto 
output = acc.get_tensor(OUTPUT); + + return profile(forward_kernel, + profiling, + "[Combine] forward_time = %.2lfms\n", + &per_device_state, + input, + output); } -bool Combine::get_int_parameter(PMParameter para, int *value) const { - switch (para) { - case PM_COMBINE_DIM: - *value = combine_dim; - return true; - case PM_COMBINE_DEGREE: - *value = combine_degree; - return true; - default: - return Op::get_int_parameter(para, value); - } +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -bool Combine::append_parallel_op_info( - std::vector ¶llel_ops) const { - ParallelOpInfo ret; - ret.op_type = op_type; - ret.parallel_dim = combine_dim; - ret.parallel_degree = combine_degree; - parallel_ops.push_back(ret); - return true; +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + + return profile(backward_kernel, + profiling, + "[Combine] forward_time = %.2lfms\n", + &per_device_state, + input_grad, + output_grad); } -tl::optional Combine::as_dot() const { - RecordFormatter rf; - { - std::ostringstream oss; - oss << "dim(" << this->combine_dim << ")"; - rf << oss.str(); - } - { - std::ostringstream oss; - oss << "deg(" << this->combine_degree << ")"; - rf << oss.str(); - } - return rf; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -/*static*/ -void Combine::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - DataType data_type = 
*((DataType *)task->args); - if (data_type == DT_FLOAT) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_DOUBLE) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT32) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT64) { - forward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Combine forward"); - } +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + CombineAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + // TODO: to be implemented + float forward_time = 0.5; + float backward_time = 0.5; + float sync_time = 0.0; + return make_metrics(forward_time, backward_time, sync_time, env); } -template -void Combine::forward_task_with_type(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(output_domain == input_domain); - - const DT *input_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - DT *output_ptr = helperGetTensorPointerWO
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - - forward_kernel
(input_ptr, output_ptr, output_domain.get_volume()); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + + init.add_input_slot(INPUT); + + register_task(COMBINE_INIT_TASK_ID, "Combine Init", init, init_task); } -void Combine::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - DataType data_type = *((DataType *)task->args); - if (data_type == DT_FLOAT) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_DOUBLE) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT32) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (data_type == DT_INT64) { - backward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Combine backward"); - } +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + register_task(COMBINE_FWD_TASK_ID, "Combine Fwd", fwd, forward_task); } -template -void Combine::backward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - Domain output_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Domain input_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(output_grad_domain == input_grad_domain); - - const DT *output_grad_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - DT *input_grad_ptr = helperGetTensorPointerRW
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - - backward_kernel
( - output_grad_ptr, input_grad_ptr, output_grad_domain.get_volume()); +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(COMBINE_FWD_TASK_ID)); + + register_task(COMBINE_BWD_TASK_ID, "Combine Bwd", bwd, backward_task); } }; // namespace FlexFlow -namespace std { -size_t hash::operator()( - FlexFlow::CombineParams const ¶ms) const { - size_t key = 0; - hash_combine(key, params.combine_legion_dim); - hash_combine(key, params.combine_degree); - return key; -} -}; // namespace std +namespace std {}; // namespace std diff --git a/lib/runtime/src/ops/combine.h b/lib/runtime/src/ops/combine.h index 455a8ea780..512c3be363 100644 --- a/lib/runtime/src/ops/combine.h +++ b/lib/runtime/src/ops/combine.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_COMBINE_H #include "op-attrs/ops/combine.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -20,7 +20,7 @@ OpTaskInvocation backward(CombineAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, CombineAttrs const &attrs, - ParallelTensorShape const &input_shape, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); @@ -82,3 +82,273 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + +// size_t hash::operator()( +// FlexFlow::CombineParams const ¶ms) const { +// size_t key = 0; +// hash_combine(key, params.combine_legion_dim); +// hash_combine(key, params.combine_degree); +// return key; +// } + +// template +// void Combine::backward_task_with_type( +// Task const *task, +// std::vector const ®ions, +// Context ctx, +// Runtime *runtime) { +// Domain output_grad_domain = runtime->get_index_space_domain( +// ctx, task->regions[0].region.get_index_space()); +// Domain input_grad_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// 
assert(output_grad_domain == input_grad_domain); + +// const DT *output_grad_ptr = helperGetTensorPointerRO
( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// DT *input_grad_ptr = helperGetTensorPointerRW
( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); + +// backward_kernel
( +// output_grad_ptr, input_grad_ptr, output_grad_domain.get_volume()); +// } + +// void Combine::backward_task(Task const *task, +// std::vector const ®ions, +// Context ctx, +// Runtime *runtime) { +// assert(regions.size() == 2); +// assert(task->regions.size() == 2); +// DataType data_type = *((DataType *)task->args); +// if (data_type == DT_FLOAT) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_DOUBLE) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT32) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT64) { +// backward_task_with_type(task, regions, ctx, runtime); +// } else { +// assert(false && "Unsupported data type in Combine backward"); +// } +// } + +// bool Combine::get_int_parameter(PMParameter para, int *value) const { +// switch (para) { +// case PM_COMBINE_DIM: +// *value = combine_dim; +// return true; +// case PM_COMBINE_DEGREE: +// *value = combine_degree; +// return true; +// default: +// return Op::get_int_parameter(para, value); +// } +// } + +// bool Combine::append_parallel_op_info( +// std::vector ¶llel_ops) const { +// ParallelOpInfo ret; +// ret.op_type = op_type; +// ret.parallel_dim = combine_dim; +// ret.parallel_degree = combine_degree; +// parallel_ops.push_back(ret); +// return true; +// } + +// tl::optional Combine::as_dot() const { +// RecordFormatter rf; +// { +// std::ostringstream oss; +// oss << "dim(" << this->combine_dim << ")"; +// rf << oss.str(); +// } +// { +// std::ostringstream oss; +// oss << "deg(" << this->combine_degree << ")"; +// rf << oss.str(); +// } +// return rf; +// } + +// void Combine::init(FFModel const &ff) { +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// assert(numOutputs == 1); +// assert(numInputs == 1); +// IndexLauncher launcher(COMBINE_INIT_TASK_ID, +// parallel_is, 
+// TaskArgument(this, sizeof(Combine)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement( +// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// } + +// void Combine::create_input_partition(FFModel &ff) { +// assert(outputs[0]->part != LogicalPartition::NO_PART); +// assert(inputs[0]->part != LogicalPartition::NO_PART); +// ff.create_disjoint_partition(outputs[0]->num_dims, +// outputs[0]->dims, +// outputs[0]->parallel_is, +// inputs[0]->region, +// input_lp); +// ff.create_disjoint_partition(inputs[0]->num_dims, +// inputs[0]->dims, +// inputs[0]->parallel_is, +// outputs[0]->region_grad, +// output_grad_lp); +// } + +// void Combine::forward(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// assert(numOutputs == 1); +// assert(numInputs == 1); +// assert(inputs[0]->data_type == outputs[0]->data_type); +// DataType data_type = inputs[0]->data_type; +// IndexLauncher launcher(COMBINE_FWD_TASK_ID, +// outputs[0]->parallel_is, +// TaskArgument(&data_type, sizeof(data_type)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement( +// input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// 
launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +// void Combine::backward(FFModel const &ff) { +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// assert(numOutputs == 1); +// assert(numInputs == 1); +// assert(inputs[0]->data_type == outputs[0]->data_type); +// DataType data_type = inputs[0]->data_type; +// IndexLauncher launcher(COMBINE_BWD_TASK_ID, +// inputs[0]->parallel_is, +// TaskArgument(&data_type, sizeof(DataType)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// inputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(output_grad_lp, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// READ_WRITE, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +// CombineParams Combine::get_params() const { +// CombineParams params; +// params.combine_legion_dim = this->combine_dim; +// params.combine_degree = this->combine_degree; +// return params; +// } + +// Combine::Combine(FFModel &model, +// CombineParams const ¶ms, +// ParallelTensor const input, +// char const *name) +// : Combine(model, +// input, +// params.combine_legion_dim, +// params.combine_degree, +// name) {} + +// Combine::Combine(FFModel &model, +// const ParallelTensor _input, +// int _combine_legion_dim, +// int _combine_degree, +// char const *name) +// : ParallelOp(model, OP_COMBINE, name, _input), +// combine_dim(_combine_legion_dim), combine_degree(_combine_degree) { +// int numdim = _input->num_dims; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdim; i++) { +// dims[i] = _input->dims[i]; +// } +// assert(combine_degree > 0 && "Must use combine_degree > 
0"); +// assert(dims[combine_dim].degree % combine_degree == 0); +// dims[combine_dim].degree /= combine_degree; +// ParallelTensorBase::update_parallel_ids(numdim, dims); +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// numdim, dims, DT_FLOAT, this); +// // inputs[0]->print("Combine::input"); +// // outputs[0]->print("Combine::output"); +// } + +// /*static*/ +// void Combine::forward_task(Task const *task, +// std::vector const ®ions, +// Context ctx, +// Runtime *runtime) { +// assert(regions.size() == 2); +// assert(task->regions.size() == 2); +// DataType data_type = *((DataType *)task->args); +// if (data_type == DT_FLOAT) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_DOUBLE) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT32) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else if (data_type == DT_INT64) { +// forward_task_with_type(task, regions, ctx, runtime); +// } else { +// assert(false && "Unsupported data type in Combine forward"); +// } +// } + +// template +// void Combine::forward_task_with_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// Domain input_domain = runtime->get_index_space_domain( +// ctx, task->regions[0].region.get_index_space()); +// Domain output_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// assert(output_domain == input_domain); + +// const DT *input_ptr = helperGetTensorPointerRO
( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// DT *output_ptr = helperGetTensorPointerWO
( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); + +// forward_kernel
(input_ptr, output_ptr, output_domain.get_volume()); +// } diff --git a/lib/runtime/src/ops/concat.cc b/lib/runtime/src/ops/concat.cc index f17a33b956..d8f610c3ea 100644 --- a/lib/runtime/src/ops/concat.cc +++ b/lib/runtime/src/ops/concat.cc @@ -16,537 +16,199 @@ #include "concat.h" #include "kernels/concat_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" +#include "op-attrs/get_output_shapes.h" +#include "task_spec/variadic_tensor_ref.h" #include "utils/hash-utils.h" namespace FlexFlow { -enum Slots { - INPUTS, - OUTPUT, - INPUT_GRADS, - OUTPUT_GRAD, - ATTRS, - PROFILING -} +using namespace FlexFlow::Kernels::Concat; -// declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; -using PCG::Node; -using namespace FlexFlow::Kernels::Concat; +enum Slots { + INPUTS, + OUTPUT, + ATTRS, + PROFILING, + HANDLE, + PER_DEVICE_STATE, + NUM_INPUTS +}; -bool operator==(ConcatParams const &lhs, ConcatParams const &rhs) { - return lhs.axis == rhs.axis; -} +OpTaskInvocation init(ConcatAttrs const &attrs) { + OpTaskBinding binding; -ConcatParams Concat::get_params() const { - ConcatParams params; - params.axis = legion_axis; - return params; -} + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(ATTRS, attrs); -Tensor - FFModel::concat(int n, Tensor const *tensors, int axis, char const *name) { - Layer *concat = new Layer(this, - OP_CONCAT, - DT_FLOAT, - name, - n /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - tensors); - int numdim = tensors[0]->num_dims; - // Making sure axis is between [0, numdim) - axis = (axis % numdim + numdim) % numdim; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - 
dims[i] = tensors[0]->dims[i]; - } - for (int i = 1; i < n; i++) { - assert(tensors[i]->data_type == tensors[0]->data_type); - assert(tensors[i]->num_dims == tensors[0]->num_dims); - for (int j = 0; j < numdim; j++) { - if (j != numdim - axis - 1) { - assert(tensors[i]->dims[j] == tensors[0]->dims[j]); - } else { - dims[j] += tensors[i]->dims[j]; - } - } - } - concat->outputs[0] = create_tensor_legion_ordering( - numdim, dims, tensors[0]->data_type, concat, 0, true /*create_grad*/); - concat->add_int_property("legion_axis", numdim - axis - 1); - layers.push_back(concat); - return concat->outputs[0]; + return {CONCAT_INIT_TASK_ID, binding}; } -Op *Concat::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("legion_axis", value); - int legion_axis = value; - return new Concat( - model, inputs.size(), inputs.data(), legion_axis, layer->name); -} +OpTaskInvocation forward(ConcatAttrs const &attrs) { + OpTaskBinding binding; + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); + binding.bind(INPUTS, get_input_tensors()); + binding.bind(OUTPUT, output_tensor(0)); + binding.bind(NUM_INPUTS, get_number_inputs()); + binding.bind_arg(PROFILING, profiling_settings()); -Concat::Concat(FFModel &model, - int _n, - ParallelTensor const *_tensors, - int _legion_axis, - char const *name) - : Op(model, - OP_CONCAT, - DT_FLOAT, - name, - _n /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - _tensors), - legion_axis(_legion_axis) { - int num_dim = inputs[0]->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < num_dim; i++) { - dims[i] = inputs[0]->dims[i]; - } - for (int i = 1; i < numInputs; i++) { - assert(inputs[i]->data_type == inputs[0]->data_type); - assert(inputs[i]->num_dims == inputs[0]->num_dims); - for (int j = 0; j < num_dim; j++) { - if (j != legion_axis) { - assert(inputs[i]->dims[j] == inputs[0]->dims[j]); - } else { - // Assert that the concat dim cannot 
be parallelized - assert(inputs[i]->dims[j].parallel_idx == -1); - assert(inputs[i]->dims[j].degree == 1); - dims[j].size += inputs[i]->dims[j].size; - } - } - } - numOutputs = 1; - outputs[0] = model.create_parallel_tensor_legion_ordering( - num_dim, dims, inputs[0]->data_type, this); + return {CONCAT_FWD_TASK_ID, binding}; } -Concat::Concat(FFModel &model, - ConcatParams const ¶ms, - std::vector const &inputs, - char const *name) - : Concat(model, inputs.size(), inputs.data(), params.axis, name) {} +OpTaskInvocation backward(ConcatAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); + return {CONCAT_BWD_TASK_ID, b}; +} - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + auto const &attrs = acc.get_argument(ATTRS); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); - init.add_input_slot(INPUTS, SlotType::VARIADIC); - init.add_output_slot(OUTPUT); + DeviceSpecific per_device_state = + acc.create_device_specific(init_kernel(attrs.axis)); + return per_device_state; +} - return init; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + int number_of_inputs = acc.get_argument(NUM_INPUTS); + ProfilingSettings profiling = acc.get_argument(PROFILING); - fwd.add_arg_slot(ATTRS); + auto output = acc.get_tensor(OUTPUT); + auto inputs = acc.get_variadic_tensor(INPUTS); - fwd.add_input_slot(INPUTS, SlotType::VARIADIC); - fwd.add_output_slot(OUTPUT); + return profile(forward_kernel, + 
profiling, + "[Concat] forward_time = %.2lfms\n", + &per_device_state, + output, + inputs, + number_of_inputs); +} - return init; +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + int number_of_inputs = acc.get_argument(NUM_INPUTS); + ProfilingSettings profiling = acc.get_argument(PROFILING); - bwd.add_arg_slot(ATTRS); + auto input_grads = acc.get_variadic_tensor_grad(INPUTS); + auto output_grad = acc.get_tensor_grad(OUTPUT); - bwd.add_input_grad_slot(INPUT_GRADS, SlotType::VARIADIC); - bwd.add_output_grad_slot(OUTPUT_GRAD); + assert(number_of_inputs <= MAX_NUM_INPUTS); - return bwd; + return profile(backward_kernel, + profiling, + "[Concat] backward_time = %.2lfms\n", + &per_device_state, + output_grad, + input_grads, + number_of_inputs); } -OpTaskBinding Concat::get_init_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(PROFILING, this->profiling); - binding.bind_arg(ATTRS, this->attrs); - - return binding; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -OpTaskBinding Concat::get_fwd_task_binding() const { - OpTaskBinding binding; +CostMetrics + measure_operator_cost(SimEnvFactory const &sim, + ConcatAttrs const &attrs, + InputVariadicParallelTensorDesc const &inputs_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + int numInputs = (inputs_shape.shapes).size(); + assert(numInputs <= MAX_NUM_INPUTS); - binding.bind_arg(ATTRS, this->attrs); + auto env = sim.new_environment(); - for (int i = 0; i < this->attrs.n; 
i++) { - binding.bind(INPUTS, input_tensor(i)); - } + ParallelTensorShape output_shape = + get_output_shape(attrs, inputs_shape.shapes); - binding.bind(OUTPUT, output_tensor(0)); + SimTaskBinding init_binding; + init_binding.bind_arg(PROFILING, settings); + init_binding.bind_arg(ATTRS, attrs); - return binding; -} + auto init_accessor = env.get_init_accessor(CONCAT_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); -OpTaskBinding Concat::get_bwd_task_binding() const { - OpTaskBinding binding; + SimTaskBinding fwd_binding; + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + fwd_binding.bind(INPUTS, inputs_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(NUM_INPUTS, numInputs); + fwd_binding.bind_arg(PROFILING, settings); - binding.bind_arg(ATTRS, this->attrs); + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - for (int i = 0; i < this->attrs.n; i++) { - binding.bind(INPUT_GRADS, input_tensor(i).grad()); - } + auto fwd_accessor = env.get_fwd_accessor(CONCAT_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(CONCAT_BWD_TASK_ID, bwd_binding); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); - return binding; + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -void Concat::init(FFModel const &ff) { - this->execute_task(ff, CONCAT_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(CONCAT_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Concat)), - // argmap, - // 
Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[i]->region)); - // launcher.add_field(i + 1, FID_DATA); - // } - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // inputs[i]->region_grad)); - // launcher.add_field(i + numInputs + 1, FID_DATA); - // } - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); -} +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); -PerDeviceOpState *Concat::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - - FFHandler handler = *((FFHandler const *)task->local_args); - ConcatPerDeviceState *m = new ConcatPerDeviceState(handler); - // Note that our internal axis index ordering is opposite to other frameworks - init_meta(m, attrs.legion_axis); - m->profiling = profiling; - std::strcpy(m->op_name, attrs.name); - return m; -} + init.add_arg_slot(ATTRS); + init.add_arg_slot(PROFILING); -void Concat::forward(FFModel const &ff) { - this->execute_task(ff, CONCAT_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher 
launcher(CONCAT_FWD_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Concat)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[i]->region)); - // launcher.add_field(i + 1, FID_DATA); - // } - // runtime->execute_index_space(ctx, launcher); + register_task(CONCAT_INIT_TASK_ID, "Concat Init", init, init_task); } -/* - regions[0](O): output - regions[1..numInputs](I): inputs -*/ -void Concat::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - // Concat const *cc = (Concat *)task->args; - ConcatPerDeviceState const *m = *((ConcatPerDeviceState **)task->local_args); - // Note that our internal axis index ordering is opposite to other frameworks - assert(regions.size() == attrs.n + 1); - assert(task->regions.size() == attrs.n + 1); - // Domain out_domain = runtime->get_index_space_domain( - // ctx, task->regions[0].region.get_index_space()); - // GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - // DT_FLOAT, regions[0], task->regions[0], FID_DATA, ctx, runtime); - // assert(out_domain.get_dim() == cc->outputs[0].num_dims); - // Domain in_domain[MAX_NUM_INPUTS]; - // for (int i = 0; i < cc->numInputs; i++) - // in_domain[i] = runtime->get_index_space_domain( - // ctx, task->regions[i + 1].region.get_index_space()); - // float *output = helperGetTensorPointerWO( - // regions[0], task->regions[0], FID_DATA, ctx, runtime); - - auto output = acc.get_tensor(OUTPUT); - auto inputs = 
acc.get_variadic_tensor(INPUTS); - - // GenericTensorAccessorR inputs[MAX_NUM_INPUTS]; - // for (int i = 0; i < attrs.n; i++) { - // // inputs[i] = helperGetTensorPointerRO( - // // regions[i + 1], task->regions[i + 1], FID_DATA, ctx, runtime); - // inputs[i] = helperGetGenericTensorAccessorRO( - // DT_FLOAT, regions[i + 1], task->regions[i + 1], FID_DATA, ctx, - // runtime); - // } - profile(forward_kernel, - m->profiling, - "[Concat] forward_time = %.2lfms\n", - m, - output, - inputs, - attrs.n) -} +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); -void Concat::backward(FFModel const &ff) { - this->execute_task(ff, CONCAT_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(CONCAT_BWD_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Concat)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(0, FID_DATA); - // for (int i = 0; i < numInputs; i++) { - // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, - // 0 /*projection id*/, - // READ_WRITE, - // EXCLUSIVE, - // inputs[i]->region_grad)); - // // LogicalRegion lr = inputs[i]->region_grad; - // // printf("concat[%d]: region(%d,%d,%d)\n", i+1, - // // lr.get_index_space().get_id(), lr.get_field_space().get_id(), - // // lr.get_tree_id()); - // launcher.add_field(i + 1, FID_DATA); - // } - // runtime->execute_index_space(ctx, launcher); -} + fwd.add_arg_slot(NUM_INPUTS); + fwd.add_arg_slot(PROFILING); + fwd.add_input_slot(INPUTS, SlotType::VARIADIC); + fwd.add_output_slot(OUTPUT); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); -/* - 
regions[0](I): output_grad - regions[1..numInputs](I/O): input_grad -*/ -void Concat::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - // Concat const *cc = (Concat *)task->args; - ConcatPerDeviceState const *m = *((ConcatPerDeviceState **)task->local_args); - // Note that our internal axis index ordering is opposite to other frameworks - assert(regions.size() == attrs.n + 1); - assert(task->regions.size() == attrs.n + 1); - assert(attrs.n <= MAX_NUM_INPUTS); - // Domain out_grad_domain = runtime->get_index_space_domain( - // ctx, task->regions[0].region.get_index_space()); - // assert(out_grad_domain.get_dim() == cc->outputs[0].num_dims); - // Domain in_grad_domains[MAX_NUM_INPUTS]; - // for (int i = 0; i < cc->numInputs; i++) - // in_grad_domains[i] = runtime->get_index_space_domain( - // ctx, task->regions[i + 1].region.get_index_space()); - // float const *output_grad = helperGetTensorPointerRO( - // regions[0], task->regions[0], FID_DATA, ctx, runtime); - - auto input_grads = acc.get_variadic_tensor(INPUT_GRADS); - auto output_grad = acc.get_tensor(OUTPUT_GRAD); - - // GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( - // DT_FLOAT, regions[0], task->regions[0], FID_DATA, ctx, runtime); - // GenericTensorAccessorW input_grads[MAX_NUM_INPUTS]; - // for (int i = 0; i < attrs.n; i++) { - // // input_grads[i] = helperGetTensorPointerRW( - // // regions[i + 1], task->regions[i + 1], FID_DATA, ctx, runtime); - // input_grads[i] = helperGetGenericTensorAccessorRW( - // DT_FLOAT, regions[i + 1], task->regions[i + 1], FID_DATA, ctx, - // runtime); - // } - - profile(backward_kernel, - m->profiling, - "[Concat] backward_time = %.2lfms\n", - m, - output_grad, - input_grads, - attrs.n) + register_task(CONCAT_FWD_TASK_ID, "Concat Fwd", fwd, forward_task); } -bool Concat::get_int_parameter(PMParameter para, int *value) const { - switch (para) { - 
case PM_AXIS: - *value = legion_axis; - return true; - default: - return Op::get_int_parameter(para, value); - } -} +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(CONCAT_FWD_TASK_ID)); -bool Concat::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - assert(numInputs <= MAX_NUM_INPUTS); - ParallelTensorBase sub_inputs[MAX_NUM_INPUTS], sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - for (int i = 0; i < numInputs; i++) { - if (!inputs[i]->get_sub_tensor(mv, sub_inputs[i])) { - return false; - } - } - - ConcatPerDeviceState *m = sim->concat_meta; - init_meta(m, this->legion_axis); - - sim->free_all(); - float *input_ptrs[MAX_NUM_INPUTS]; - float *input_grad_ptrs[MAX_NUM_INPUTS]; - bool out_of_memory = false; - for (int i = 0; i < numInputs; i++) { - input_ptrs[i] = - (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); - out_of_memory = out_of_memory || (input_ptrs[i] == NULL); - } - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - Domain out_domain = sub_output.get_domain(); - float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - GenericTensorAccessorW output_acc(DT_FLOAT, out_domain, output_ptr); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - out_of_memory = out_of_memory || (output_ptr == NULL); - if (out_of_memory) { - cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - return true; - } - - Domain in_domains[MAX_NUM_INPUTS]; - GenericTensorAccessorR input_acc[MAX_NUM_INPUTS]; - for (int i = 0; i < numInputs; i++) { - in_domains[i] = sub_inputs[i].get_domain(); - input_acc[i] = - GenericTensorAccessorR(DT_FLOAT, in_domains[i], input_ptrs[i]); - } - - assert(m->profiling == false); - - std::function forward, backward; - forward = [&](ffStream_t 
stream) { - forward_kernel(stream, m, output_acc, input_acc, numInputs); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - GenericTensorAccessorW input_grad_accs[MAX_NUM_INPUTS]; - for (int i = 0; i < numInputs; i++) { - input_grad_ptrs[i] = - (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); - out_of_memory = out_of_memory || (input_grad_ptrs[i] == NULL); - input_grad_accs[i] = - GenericTensorAccessorW(DT_FLOAT, in_domains[i], input_grad_ptrs[i]); - } - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - float *output_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - GenericTensorAccessorR output_grad_acc( - DT_FLOAT, out_domain, output_grad_ptr); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - out_of_memory = out_of_memory || (output_grad_ptr == NULL); - if (out_of_memory) { - cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - return true; - } - backward = [&](ffStream_t stream) { - backward_kernel(stream, m, output_grad_acc, input_grad_accs, numInputs); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf( - "[Measure Concat] name(%s) forward_time(%.4lf) backward_time(%.4lf)\n", - name, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure Concat] name(%s) forward_time(%.4lf)\n", - name, - cost_metrics.forward_time); - } - - return true; + register_task(CONCAT_BWD_TASK_ID, "BatchMatmul Bwd", bwd, backward_task); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/concat.h b/lib/runtime/src/ops/concat.h index 0e0c0c2523..0493006345 100644 --- a/lib/runtime/src/ops/concat.h +++ b/lib/runtime/src/ops/concat.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_CONCAT_H #include "op-attrs/ops/concat.h" -#include "op_task_invocation.h" #include "sim_environment.h" 
+#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -76,3 +76,335 @@ CostMetrics } // namespace FlexFlow #endif + +// bool operator==(ConcatParams const &lhs, ConcatParams const &rhs) { +// return lhs.axis == rhs.axis; +// } + +// ConcatParams Concat::get_params() const { +// ConcatParams params; +// params.axis = legion_axis; +// return params; +// } + +// Tensor +// FFModel::concat(int n, Tensor const *tensors, int axis, char const *name) +// { +// Layer *concat = new Layer(this, +// OP_CONCAT, +// DT_FLOAT, +// name, +// n /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// tensors); +// int numdim = tensors[0]->num_dims; +// // Making sure axis is between [0, numdim) +// axis = (axis % numdim + numdim) % numdim; +// int dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdim; i++) { +// dims[i] = tensors[0]->dims[i]; +// } +// for (int i = 1; i < n; i++) { +// assert(tensors[i]->data_type == tensors[0]->data_type); +// assert(tensors[i]->num_dims == tensors[0]->num_dims); +// for (int j = 0; j < numdim; j++) { +// if (j != numdim - axis - 1) { +// assert(tensors[i]->dims[j] == tensors[0]->dims[j]); +// } else { +// dims[j] += tensors[i]->dims[j]; +// } +// } +// } +// concat->outputs[0] = create_tensor_legion_ordering( +// numdim, dims, tensors[0]->data_type, concat, 0, true /*create_grad*/); +// concat->add_int_property("legion_axis", numdim - axis - 1); +// layers.push_back(concat); +// return concat->outputs[0]; +// } + +// Op *Concat::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("legion_axis", value); +// int legion_axis = value; +// return new Concat( +// model, inputs.size(), inputs.data(), legion_axis, layer->name); +// } + +// Concat::Concat(FFModel &model, +// int _n, +// ParallelTensor const *_tensors, +// int _legion_axis, +// char const *name) +// : Op(model, +// OP_CONCAT, +// DT_FLOAT, +// name, +// _n /*inputs*/, +// 0 
/*weights*/, +// 1 /*outputs*/, +// _tensors), +// legion_axis(_legion_axis) { +// int num_dim = inputs[0]->num_dims; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < num_dim; i++) { +// dims[i] = inputs[0]->dims[i]; +// } +// for (int i = 1; i < numInputs; i++) { +// assert(inputs[i]->data_type == inputs[0]->data_type); +// assert(inputs[i]->num_dims == inputs[0]->num_dims); +// for (int j = 0; j < num_dim; j++) { +// if (j != legion_axis) { +// assert(inputs[i]->dims[j] == inputs[0]->dims[j]); +// } else { +// // Assert that the concat dim cannot be parallelized +// assert(inputs[i]->dims[j].parallel_idx == -1); +// assert(inputs[i]->dims[j].degree == 1); +// dims[j].size += inputs[i]->dims[j].size; +// } +// } +// } +// numOutputs = 1; +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// num_dim, dims, inputs[0]->data_type, this); +// } + +// Concat::Concat(FFModel &model, +// ConcatParams const ¶ms, +// std::vector const &inputs, +// char const *name) +// : Concat(model, inputs.size(), inputs.data(), params.axis, name) {} + +// void Concat::init(FFModel const &ff) { +// this->execute_task(ff, CONCAT_INIT_TASK_ID, get_init_task_signature()); +// // assert(check_output_input_weight_same_parallel_is()); +// // parallel_is = outputs[0]->parallel_is; +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_init(ff, argmap); +// // IndexLauncher launcher(CONCAT_INIT_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(Concat)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // 
launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region_grad)); +// // launcher.add_field(i + numInputs + 1, FID_DATA); +// // } +// // FutureMap fm = runtime->execute_index_space(ctx, launcher); +// // fm.wait_all_results(); +// // set_opmeta_from_futuremap(ff, fm); +// } + +// void Concat::forward(FFModel const &ff) { +// this->execute_task(ff, CONCAT_FWD_TASK_ID, get_fwd_task_signature()); +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = ff.config.lg_hlr; +// // set_argumentmap_for_forward(ff, argmap); +// // IndexLauncher launcher(CONCAT_FWD_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(Concat)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// // 0 /*projection id*/, +// // WRITE_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // inputs[i]->region)); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](O): output + regions[1..numInputs](I): inputs +*/ + +// void Concat::backward(FFModel const &ff) { +// this->execute_task(ff, CONCAT_BWD_TASK_ID, get_bwd_task_signature()); +// // ArgumentMap argmap; +// // Context ctx = ff.config.lg_ctx; +// // Runtime *runtime = 
ff.config.lg_hlr; +// // set_argumentmap_for_backward(ff, argmap); +// // IndexLauncher launcher(CONCAT_BWD_TASK_ID, +// // parallel_is, +// // TaskArgument(this, sizeof(Concat)), +// // argmap, +// // Predicate::TRUE_PRED, +// // false /*must*/, +// // 0 /*mapper_id*/, +// // outputs[0]->machine_view.hash()); +// // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// // 0 /*projection id*/, +// // READ_ONLY, +// // EXCLUSIVE, +// // outputs[0]->region_grad)); +// // launcher.add_field(0, FID_DATA); +// // for (int i = 0; i < numInputs; i++) { +// // launcher.add_region_requirement(RegionRequirement(inputs[i]->part_grad, +// // 0 /*projection id*/, +// // READ_WRITE, +// // EXCLUSIVE, +// // inputs[i]->region_grad)); +// // // LogicalRegion lr = inputs[i]->region_grad; +// // // printf("concat[%d]: region(%d,%d,%d)\n", i+1, +// // // lr.get_index_space().get_id(), lr.get_field_space().get_id(), +// // // lr.get_tree_id()); +// // launcher.add_field(i + 1, FID_DATA); +// // } +// // runtime->execute_index_space(ctx, launcher); +// } + +/* + regions[0](I): output_grad + regions[1..numInputs](I/O): input_grad +*/ + +// bool Concat::get_int_parameter(PMParameter para, int *value) const { +// switch (para) { +// case PM_AXIS: +// *value = legion_axis; +// return true; +// default: +// return Op::get_int_parameter(para, value); +// } +// } + +// bool Concat::measure_operator_cost(Simulator *sim, +// MachineView const &mv, +// CostMetrics &cost_metrics) const { +// assert(numInputs <= MAX_NUM_INPUTS); +// ParallelTensorBase sub_inputs[MAX_NUM_INPUTS], sub_output; +// if (!outputs[0]->get_sub_tensor(mv, sub_output)) { +// return false; +// } +// for (int i = 0; i < numInputs; i++) { +// if (!inputs[i]->get_sub_tensor(mv, sub_inputs[i])) { +// return false; +// } +// } + +// ConcatPerDeviceState *m = sim->concat_meta; +// init_meta(m, this->legion_axis); + +// sim->free_all(); +// float *input_ptrs[MAX_NUM_INPUTS]; +// float 
*input_grad_ptrs[MAX_NUM_INPUTS]; +// bool out_of_memory = false; +// for (int i = 0; i < numInputs; i++) { +// input_ptrs[i] = +// (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); +// out_of_memory = out_of_memory || (input_ptrs[i] == NULL); +// } +// cost_metrics.inputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); + +// Domain out_domain = sub_output.get_domain(); +// float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), +// DT_FLOAT); GenericTensorAccessorW output_acc(DT_FLOAT, out_domain, +// output_ptr); cost_metrics.outputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); + +// out_of_memory = out_of_memory || (output_ptr == NULL); +// if (out_of_memory) { +// cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// return true; +// } + +// Domain in_domains[MAX_NUM_INPUTS]; +// GenericTensorAccessorR input_acc[MAX_NUM_INPUTS]; +// for (int i = 0; i < numInputs; i++) { +// in_domains[i] = sub_inputs[i].get_domain(); +// input_acc[i] = +// GenericTensorAccessorR(DT_FLOAT, in_domains[i], input_ptrs[i]); +// } + +// assert(m->profiling == false); + +// std::function forward, backward; +// forward = [&](ffStream_t stream) { +// forward_kernel(stream, m, output_acc, input_acc, numInputs); +// }; +// if (sim->computationMode == COMP_MODE_TRAINING) { +// GenericTensorAccessorW input_grad_accs[MAX_NUM_INPUTS]; +// for (int i = 0; i < numInputs; i++) { +// input_grad_ptrs[i] = +// (float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT); +// out_of_memory = out_of_memory || (input_grad_ptrs[i] == NULL); +// input_grad_accs[i] = +// GenericTensorAccessorW(DT_FLOAT, in_domains[i], +// input_grad_ptrs[i]); +// } +// cost_metrics.inputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); float *output_grad_ptr = +// (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); +// GenericTensorAccessorR output_grad_acc( +// DT_FLOAT, 
out_domain, output_grad_ptr); +// cost_metrics.outputs_memory += +// cost_metrics.total_mem_diff_from(sim->offset); + +// out_of_memory = out_of_memory || (output_grad_ptr == NULL); +// if (out_of_memory) { +// cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; +// return true; +// } +// backward = [&](ffStream_t stream) { +// backward_kernel(stream, m, output_grad_acc, input_grad_accs, +// numInputs); +// }; +// } + +// inner_measure_operator_cost(sim, forward, backward, cost_metrics); + +// if (sim->computationMode == COMP_MODE_TRAINING) { +// printf( +// "[Measure Concat] name(%s) forward_time(%.4lf) +// backward_time(%.4lf)\n", name, cost_metrics.forward_time, +// cost_metrics.backward_time); +// } else { +// printf("[Measure Concat] name(%s) forward_time(%.4lf)\n", +// name, +// cost_metrics.forward_time); +// } + +// return true; +// } diff --git a/lib/runtime/src/ops/conv_2d.cc b/lib/runtime/src/ops/conv_2d.cc index e362c73f92..8379da80ff 100644 --- a/lib/runtime/src/ops/conv_2d.cc +++ b/lib/runtime/src/ops/conv_2d.cc @@ -1,1051 +1,238 @@ #include "conv_2d.h" #include "kernels/conv_2d_kernels.h" -#include "layer.h" #include "legion/legion_utilities.h" #include "mpark/variant.hpp" -#include "task_spec.h" +#include "op-attrs/get_output_shapes.h" #include "utils/hash-utils.h" namespace FlexFlow { -enum Slots { - INPUT, - OUTPUT, - FILTER, - BIAS, - FILTER_GRAD, - INPUT_GRAD, - OUTPUT_GRAD, - BIAS_GRAD, - ATTRS, - PROFILING, -} - -// declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::InlineLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Conv2D; 
-Tensor FFModel::conv2d(Tensor const &input, - int outChannels, - int kernelH, - int kernelW, - int strideH, - int strideW, - int paddingH, - int paddingW, - ActiMode activation, - int groups, - bool use_bias, - Layer const *shared_op, - Initializer *kernel_initializer, - Initializer *bias_initializer, - char const *name) { - assert(input->num_dims() == 4); /*NCHW*/ - - Conv2DAttrs attrs = {outChannels, - kernelH, - kernelW, - strideH, - strideW, - paddingH, - paddingW, - groups, - activation, - use_bias}; - - TensorShape output_shape = get_output_shape(attrs, input->get_shape()); - Tensor output = this->tensor_mgr.create(output_shape, CreateGrad::YES, conv); - - std::vector weights; - - TensorShape kernel_shape = get_kernel_shape(attrs, input->get_shape()); - weights.push_back(this->tensor_mgr.create( - kernel_shape, CreateGrad::YES, kernel_initializer, CHOSEN_SYNC_TYPE)); - - if (use_bias) { - TensorShape bias_shape = get_bias_shape(attrs, input->get_shape()); - weights.push_back(this->tensor_mgr.create( - bias_shape, CreateGrad::YES, bias_initializer, CHOSEN_SYNC_TYPE)); - } - - Layer *conv = - this->layer_mgr.create(attrs, DT_FLOAT, name, {input}, weights, {output}); - - //{ - // int numdims = 4; - // int dims[MAX_TENSOR_DIM]; - // dims[3] = input->dims[3]; - // dims[2] = outChannels; - // dims[1] = 1 + (input->dims[1] + 2 * paddingH - kernelH) / strideH; - // dims[0] = 1 + (input->dims[0] + 2 * paddingW - kernelW) / strideW; - // conv->outputs[0] = create_tensor_legion_ordering( - // numdims, dims, DT_FLOAT, conv, 0, true /*create_grad*/); - //} - //{ - // int dims[4] = {kernelW, kernelH, input->dims[2], outChannels}; - // conv->weights[0] = create_weight_legion_ordering(4, - // dims, - // DT_FLOAT, - // conv, - // true /*create_grad*/, - // kernel_initializer, - // CHOSEN_SYNC_TYPE); - //} - // if (use_bias) { - // int dims[1] = {outChannels}; - // conv->weights[1] = create_weight_legion_ordering(1, - // dims, - // DT_FLOAT, - // conv, - // true 
/*create_grad*/, - // bias_initializer, - // CHOSEN_SYNC_TYPE); - //} - conv->add_initializer("kernel", kernel_initializer); - conv->add_initializer("bias", bias_initializer); - /* layers.push_back(conv); */ - return conv->outputs[0]; -} - -Op *Conv2D::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - return new Conv2D(model, - get(layer->attrs), - inputs, - layer->name, - false /*allocate_weights*/ - ); -} - -/* void Conv2DParams::mark_replica_dims( */ -/* ParallelTensorShape const &input, */ -/* ParallelDim output_dims[MAX_TENSOR_DIM], */ -/* ParallelDim kernel_dims[MAX_TENSOR_DIM], */ -/* ParallelDim bias_dims[MAX_TENSOR_DIM]) const { */ -/* if (output_dims != nullptr) { */ -/* output_dims[Conv2DOutput::REPLICA].is_replica_dim = true; */ -/* } */ -/* if (kernel_dims != nullptr) { */ -/* kernel_dims[Conv2DOutput::REPLICA].is_replica_dim = true; */ -/* } */ -/* if (bias_dims != nullptr) { */ -/* bias_dims[Conv2DBias::REPLICA_1].is_replica_dim = true; */ -/* bias_dims[Conv2DBias::REPLICA_2].is_replica_dim = true; */ -/* bias_dims[Conv2DBias::REPLICA_3].is_replica_dim = true; */ -/* bias_dims[Conv2DBias::REPLICA_4].is_replica_dim = true; */ -/* } */ -/* } */ - -/* int Conv2DParams::output_size(ParallelTensorShape const &input, */ -/* ParallelDim output_dims[MAX_TENSOR_DIM]) const - * { */ -/* int input_w = input.dims[Conv2DInput::WIDTH].size; */ -/* int input_h = input.dims[Conv2DInput::HEIGHT].size; */ - -/* output_dims[Conv2DOutput::SAMPLE].size = - * input.dims[Conv2DInput::SAMPLE].size; */ -/* output_dims[Conv2DOutput::CHANNEL].size = out_channels; */ -/* output_dims[Conv2DOutput::HEIGHT].size = */ -/* 1 + (input_h + 2 * padding_h - kernel_h) / stride_h; */ -/* output_dims[Conv2DOutput::WIDTH].size = */ -/* 1 + (input_w + 2 * padding_w - kernel_w) / stride_w; */ - -/* return input.num_dims; */ -/* }; */ - -/* int Conv2DParams::kernel_size(ParallelTensorShape const &input, */ -/* ParallelDim 
kernel_dims[MAX_TENSOR_DIM]) const - * { */ -/* kernel_dims[Conv2DKernel::CHANNEL_OUT].size = this->out_channels; */ -/* kernel_dims[Conv2DKernel::CHANNEL_IN].size = */ -/* input.dims[Conv2DInput::CHANNEL].size / this->groups; */ -/* kernel_dims[Conv2DKernel::HEIGHT].size = */ -/* this->kernel_h * input.dims[Conv2DInput::HEIGHT].degree; */ -/* kernel_dims[Conv2DKernel::WIDTH].size = */ -/* this->kernel_w * input.dims[Conv2DInput::WIDTH].degree; */ - -/* return Conv2DKernel::NUMDIM; */ -/* } */ - -/* int Conv2DParams::bias_size(ParallelTensorShape const &input, */ -/* ParallelDim bias_dims[MAX_TENSOR_DIM]) const { */ -/* bias_dims[Conv2DBias::CHANNEL].size = this->out_channels; */ - -/* return Conv2DBias::NUMDIM; */ -/* }; */ - -/* void Conv2DParams::solve_dims(ParallelTensorShape const &input, */ -/* ParallelDim output_dims[MAX_TENSOR_DIM], */ -/* int *output_ndims, */ -/* ParallelDim kernel_dims[MAX_TENSOR_DIM], */ -/* int *kernel_ndims, */ -/* ParallelDim bias_dims[MAX_TENSOR_DIM], */ -/* int *bias_ndims) const { */ -/* assert((output_dims == nullptr) == (output_ndims == nullptr)); */ -/* assert((kernel_dims == nullptr) == (kernel_ndims == nullptr)); */ -/* assert((bias_dims == nullptr) == (bias_ndims == nullptr)); */ - -/* std::vector mapping; */ -/* Conv2D::construct_mappings(mapping, this->use_bias); */ - -/* this->mark_replica_dims(input, output_dims, kernel_dims, bias_dims); */ - -/* std::vector output_dim_sets; */ -/* if (output_dims != nullptr) { */ -/* output_dim_sets.push_back(output_dims); */ -/* } */ - -/* std::vector weight_dim_sets; */ -/* if (kernel_dims != nullptr) { */ -/* weight_dim_sets.push_back(kernel_dims); */ -/* } */ -/* if (bias_dims != nullptr && this->use_bias) { */ -/* weight_dim_sets.push_back(bias_dims); */ -/* } */ - -/* solve_parallel_dim_mappings( */ -/* mapping, {input.dims}, weight_dim_sets, output_dim_sets); */ - -/* if (output_dims != nullptr) { */ -/* *output_ndims = this->output_size(input, output_dims); */ -/* } */ -/* if 
(kernel_dims != nullptr) { */ -/* *kernel_ndims = this->kernel_size(input, kernel_dims); */ -/* } */ -/* if (bias_dims != nullptr && this->use_bias) { */ -/* *bias_ndims = this->bias_size(input, bias_dims); */ -/* } */ -/* } */ - -/*static*/ -/* void Conv2D::construct_mappings(std::vector &out, - */ -/* bool use_bias) { */ -/* Conv2D::construct_output_mappings(out); */ -/* Conv2D::construct_weight_mappings(out, use_bias); */ -/* } */ - -/*static*/ -/* void Conv2D::construct_output_mappings( */ -/* std::vector &out) { */ -/* Op::construct_output_parallel_dims( */ -/* out, */ -/* {{Conv2DInput::CHANNEL, */ -/* MappingOperation::REPLICATE, */ -/* Conv2DOutput::REPLICA}, */ -/* {Conv2DInput::SAMPLE, MappingOperation::PARTITION, - * Conv2DOutput::SAMPLE}, */ -/* {Conv2DInput::REPLICA, */ -/* MappingOperation::PARTITION, */ -/* Conv2DOutput::CHANNEL}, */ -/* {Conv2DInput::HEIGHT, MappingOperation::PARTITION, - * Conv2DOutput::HEIGHT}, */ -/* {Conv2DInput::WIDTH, MappingOperation::PARTITION, - * Conv2DOutput::WIDTH}}); */ -/* } */ - -/*static*/ -/* void Conv2D::construct_weight_mappings( */ -/* std::vector &out, bool use_bias) { */ -/* Op::construct_weight_parallel_dims( */ -/* out, */ -/* { */ -/* {Conv2DInput::REPLICA, */ -/* MappingOperation::PARTITION, */ -/* Conv2DKernel::CHANNEL_OUT}, */ -/* {Conv2DInput::SAMPLE, */ -/* MappingOperation::REPLICATE, */ -/* Conv2DKernel::REPLICA}, */ -/* {Conv2DInput::CHANNEL, */ -/* MappingOperation::PARTITION, */ -/* Conv2DKernel::CHANNEL_IN}, */ -/* {Conv2DInput::HEIGHT, */ -/* MappingOperation::REPLICATE, */ -/* Conv2DKernel::HEIGHT}, // Kernel::{HEIGHT, WEIGHT} would both work - */ -/* // here */ -/* {Conv2DInput::WIDTH, */ -/* MappingOperation::REPLICATE, */ -/* Conv2DKernel::WIDTH}, // same as above */ -/* }, */ -/* Conv2DInput::INDEX, */ -/* Conv2DKernel::INDEX); */ - -/* if (use_bias) { */ -/* Op::construct_weight_parallel_dims( */ -/* out, */ -/* {{Conv2DInput::REPLICA, Conv2DBias::REPLICA_1}, */ -/* {Conv2DInput::SAMPLE, 
Conv2DBias::REPLICA_2}, */ -/* {Conv2DInput::CHANNEL, Conv2DBias::CHANNEL}, */ -/* {Conv2DInput::HEIGHT, Conv2DBias::REPLICA_3}, */ -/* {Conv2DInput::WIDTH, Conv2DBias::REPLICA_4}}, */ -/* Conv2DInput::INDEX, */ -/* Conv2DBias::INDEX); */ -/* } */ -/* } */ - -Conv2D::Conv2D(FFModel &model, - Conv2D const &other, - const ParallelTensor input, - bool allocate_weights) - : Conv2D(model, - other.layer_guid, - input, - other.out_channels, - other.kernel_h, - other.kernel_w, - other.stride_h, - other.stride_w, - other.padding_h, - other.padding_w, - other.activation, - other.groups, - other.use_bias, - allocate_weights, - other.name) {} - -Conv2D::Conv2D(FFModel &model, - Conv2DAttrs const &attrs, - std::vector const &inputs, - char const *name, - bool allocate_weights) - : Conv2D(model, - params.layer_guid, - input, - params.out_channels, - params.kernel_h, - params.kernel_w, - params.stride_h, - params.stride_w, - params.padding_h, - params.padding_w, - params.activation, - params.groups, - params.use_bias, - allocate_weights, - name) {} - -/* bool Conv2DParams::is_valid(ParallelTensorShape const &input) const { */ -/* ParallelTensorShape output_shape, kernel_shape, bias_shape; */ -/* this->solve_dims(input, */ -/* output_shape.dims, */ -/* &output_shape.num_dims, */ -/* kernel_shape.dims, */ -/* &kernel_shape.num_dims, */ -/* bias_shape.dims, */ -/* &bias_shape.num_dims); */ -/* bool is_valid = true; */ -/* is_valid &= input.is_valid(); */ -/* is_valid &= output_shape.is_valid(); */ -/* is_valid &= kernel_shape.is_valid(); */ -/* if (use_bias) { */ -/* is_valid &= bias_shape.is_valid(); */ -/* } */ - -/* // TODO FIXME: Currently disable parallelizing the height and width - * dimension */ -/* if (input.dims[0].degree > 1 || input.dims[1].degree > 1) { */ -/* return false; */ -/* } */ - -/* return is_valid; */ -/* } */ - -Conv2D::Conv2D(FFModel &model, - LayerID const &_layer_guid, - const ParallelTensor input, - int outChannels, - int kernelH, - int kernelW, - int 
strideH, - int strideW, - int paddingH, - int paddingW, - ActiMode activation, - int groups, - bool use_bias, - bool allocate_weights, - char const *name) - : Op(model, - OP_CONV2D, - DT_FLOAT, - name, - 1 /*inputs*/, - use_bias ? 2 : 1 /*weights*/, - allocate_weights, - 1 /*outputs*/, - input), - in_channels(input->dims[Conv2DInput::CHANNEL].size / - input->dims[Conv2DInput::CHANNEL].degree), - out_channels(outChannels), kernel_h(kernelH), kernel_w(kernelW), - stride_h(strideH), stride_w(strideW), padding_h(paddingH), - padding_w(paddingW), activation(activation), groups(groups), - use_bias(use_bias) { - // overwrite layer_guid - layer_guid = _layer_guid; - assert(input->num_dims == Conv2DInput::NUMDIM); - assert(this->stride_h > 0); - assert(this->stride_w > 0); - - ParallelDim output_dims[MAX_TENSOR_DIM], kernel_dims[MAX_TENSOR_DIM], - bias_dims[MAX_TENSOR_DIM]; - int output_ndims, kernel_ndims, bias_ndims; - - this->construct_mappings(*this->parallel_dims_mapping, this->use_bias); - this->get_params().solve_dims(this->inputs[0]->get_shape(), - output_dims, - &output_ndims, - kernel_dims, - &kernel_ndims, - bias_dims, - &bias_ndims); - - if (allocate_weights) { - Initializer *kernel_initializer = new GlorotUniform(std::rand() /*seed*/); - - weights[Conv2DKernel::INDEX] = - model.create_parallel_weight_legion_ordering(kernel_ndims, - kernel_dims, - DT_FLOAT, - NULL /*owner_op*/, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - - if (use_bias) { - Initializer *bias_initializer = new ZeroInitializer(); - - weights[Conv2DBias::INDEX] = - model.create_parallel_weight_legion_ordering(bias_ndims, - bias_dims, - DT_FLOAT, - NULL /*owner_op*/, - true /*create_grad*/, - bias_initializer, - CHOSEN_SYNC_TYPE); - } - } - - outputs[0] = model.create_parallel_tensor_legion_ordering( - output_ndims, output_dims, DT_FLOAT, this); - - assert(check_output_input_weight_parallel_dims(allocate_weights)); -} - -static OpTaskSignature get_init_task_signature() { - 
OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - - init.add_input_slot(INPUT); - init.add_output_slot(OUTPUT, WRITE_ONLY); - init.add_param_slot(FILTER); - init.add_param_slot(BIAS); - init.add_param_grad_slot(FILTER_GRAD, WRITE_ONLY); - init.add_input_grad_slot(INPUT_GRAD); - - return init; -} - -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT, WRITE_ONLY); - fwd.add_param_slot(FILTER); - fwd.add_param_slot(BIAS); - - return fwd; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); +enum Slots { + INPUT, + OUTPUT, + FILTER, + BIAS, + ATTRS, + PROFILING, + PER_DEVICE_STATE, + HANDLE +}; - bwd.add_arg_slot(ATTRS); +OpTaskInvocation init(Conv2DAttrs const &attrs) { + OpTaskBinding binding; - bwd.add_input_slot(INPUT); - bwd.add_input_grad_slot(INPUT_GRAD, READ_WRITE); - bwd.add_output_slot(OUTPUT); - bwd.add_output_grad_slot(OUTPUT_GRAD, READ_WRITE); - bwd.add_param_slot(FILTER); - bwd.add_param_grad_slot(FILTER_GRAD, READ_WRITE); - bwd.add_param_grad_slot(BIAS_GRAD, READ_WRITE); + binding.bind_arg(ATTRS, attrs); + binding.bind_arg(HANDLE, ff_handle()); - return bwd; + return {CONV2D_INIT_TASK_ID, binding}; } -OpTaskBinding Conv2d::get_init_task_binding() const { +OpTaskInvocation forward(Conv2DAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - binding.bind(FILTER, param_tensor(0)); - binding.bind(BIAS, param_tensor(1)); - binding.bind(FILTER_GRAD, param_tensor(0).grad()); - binding.bind(INPUT_GRAD, input_tensor(0).grad()); + binding.bind(FILTER, weight_tensor(0)); + 
binding.bind(BIAS, weight_tensor(1)); - return binding; + return {CONV2D_FWD_TASK_ID, binding}; } -OpTaskBinding Conv2d::get_fwd_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); - - binding.bind(INPUT, input_tensor(0)); - binding.bind(OUTPUT, output_tensor(0)); - binding.bind(FILTER, param_tensor(0)); - binding.bind(BIAS, param_tensor(1)); +OpTaskInvocation backward(Conv2DAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return binding; + return {CONV2D_BWD_TASK_ID, binding}; } -OpTaskBinding Conv2d::get_bwd_task_binding() const { - OpTaskBinding binding; +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { - binding.bind_arg(ATTRS, this->attrs); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + auto const &attrs = acc.get_argument(ATTRS); - binding.bind(INPUT, input_tensor(0)); - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(OUTPUT, output_tensor(0)); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); - binding.bind(FILTER, param_tensor(0)); - binding.bind(FILTER_GRAD, param_tensor(0).grad()); - binding.bind(BIAS_GRAD, param_tensor(1).grad()); + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t biasTensor; + ffTensorDescriptor_t outputTensor; + ffFilterDescriptor_t filterDesc; + ffActivationDescriptor_t actiDesc; + ffConvolutionDescriptor_t convDesc; + ffConvolutionFwdAlgo_t fwdAlgo; + ffConvolutionBwdFilterAlgo_t bwdFilterAlgo; + ffConvolutionBwdDataAlgo_t bwdDataAlgo; - return binding; -} + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, + inputTensor, + biasTensor, + outputTensor, + filterDesc, + actiDesc, + convDesc, + fwdAlgo, + bwdFilterAlgo, + bwdDataAlgo, + attrs.activation, + attrs.use_bias)); -void Conv2D::init(FFModel const &ff) { - this->execute_task(ff, CONV2D_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // 
parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(CONV2D_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Conv2D)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // weights[0]->region)); - // launcher.add_field(2, FID_DATA); - // // launcher.add_region_requirement( - // // RegionRequirement(weights[1]->part, 0/*projection id*/, - // // READ_ONLY, EXCLUSIVE, weights[1]->region)); - // // launcher.add_field(3, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // weights[0]->region_grad)); - // launcher.add_field(3, FID_DATA); - // // launcher.add_region_requirement( - // // RegionRequirement(inputs[0]->part_grad, 0/*projection id*/, - // // WRITE_ONLY, EXCLUSIVE, inputs[0]->region_grad)); - // // launcher.add_field(4, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); + return per_device_state; } -/* - regions[0]: input - regions[1]: output - regions[2](I): filter - regions[3](I): bias - regions[4](O): filter_grad - regions[5](O): input_grad -*/ -PerDeviceOpState *Conv2D::init_task(Task const *task, - std::vector 
const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - // Conv2D const *conv = (Conv2D *)task->args; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - FFHandler handle = *((FFHandler const *)task->local_args); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - // TensorAccessorR acc_input( - // regions[0], task->regions[0], FID_DATA, ctx, runtime); - // TensorAccessorW acc_output(regions[1], - // task->regions[1], - // FID_DATA, - // ctx, - // runtime, - // false - // /*readOutput*/); - // TensorAccessorR acc_kernel( - // regions[2], task->regions[2], FID_DATA, ctx, runtime); - // TensorAccessorR acc_bias( - // regions[3], task->regions[3], FID_DATA, ctx, runtime); - // TensorAccessorW acc_kernel_grad( - // regions[3], - // task->regions[3], - // FID_DATA, - // ctx, - // runtime, - // false /*readOutput*/); - // TensorAccessorW acc_input_grad( - // regions[4], task->regions[4], FID_DATA, ctx, runtime, - // false/*readOutput*/); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto filter = acc.get_tensor(FILTER); - auto bias = acc.get_tensor(BIAS); - auto filter_grad = acc.get_tensor(FILTER_GRAD); - auto input_grad = acc.get_tensor(INPUT_GRAD); - - Conv2DPerDeviceState *m = new Conv2DPerDeviceState(handle); - m->relu = attrs.activation == AC_MODE_RELU; - m->use_bias = attrs.use_bias; - m->profiling = profiling; - // m->trainableInputs[0] = conv->trainableInputs[0]; ?? 
- std::strcpy(m->op_name, attrs.name); - - int input_w = input.shape[0]; - int input_h = input.shape[1]; - int input_c = input.shape[2]; - int input_n = input.shape[3]; - int output_w = output.shape[0]; - int output_h = output.shape[1]; - int output_c = output.shape[2]; - int output_n = output.shape[3]; - - printf("init conv (input): n(%d) c(%d) h(%d) w(%d)\n", - input_n, - input_c, - input_h, - input_w); - printf("init conv (output): n(%d) c(%d) h(%d) w(%d)\n", - output_n, - output_c, - output_h, - output_w); - - // printf("convDim: padding(%d %d) stride(%d %d)\n", conv->padding_h, - // conv->padding_w, conv->stride_h, conv->stride_w); - int pad_h = - ((output_h - 1) * attrs.stride_h + attrs.kernel_h - input_h + 1) / 2; - int pad_w = - ((output_w - 1) * attrs.stride_w + attrs.kernel_w - input_w + 1) / 2; - if (pad_h != attrs.padding_h) { - printf("Warning: changing conv_padding_h to satisfy output_h size\n"); - } - if (pad_w != attrs.padding_w) { - printf("Warning: changing conv_padding_w to satisfy output_w size\n"); - } - - init_kernel(m, - input_w, - input_h, - input_c, - input_n, - output_w, - output_h, - output_c, - output_n, - attrs.kernel_h, - attrs.kernel_w, - attrs.groups, - attrs.stride_h, - attrs.stride_w, - pad_h, - pad_w, - input.get_float_ptr(), - output.get_float_ptr(), - filter.get_float_ptr(), - filter_grad.get_float_ptr()); - - return m; + return init_task_impl(acc); } -// TaskSpec Conv2D::get_tasks_spec() const { -// OpTasksSpec spec { -// CONV2D_INIT_TASK_ID, -// CONV2D_FWD_TASK_ID, -// CONV2D_BWD_TASK_ID -// }; -// auto &fwd = spec.get_fwd(); - -// fwd.add_input_slot(INPUT); -// fwd.add_param_slot(KERNEL); -// fwd.add_output_slot(OUTPUT); - -// auto input = spec.input_tensor(0); -// auto kernel = spec.param_tensor(0); -// auto bias = spec.param_tensor(1); -// auto output = spec.output_tensor(0); - -// fwd[INPUT] = input; -// fwd[KERNEL] = kernel; -// if (this->use_bias) { -// fwd[BIAS] = bias; -// } -// fwd[OUTPUT] = output; - -// return spec; 
-// } - -/* TaskSpec Conv2D::get_forward_task_spec() const { */ -/* TaskSpec spec = { CONV2D_FWD_TASK_ID, Pass::FWD }; */ - -/* auto input = spec.add_tensor(TensorRole::INPUT, 0); */ -/* auto kernel = spec.add_tensor(TensorRole::PARAM, 0); */ -/* auto bias = spec.add_tensor(TensorRole::BIAS, 1); */ -/* auto output = spec.add_tensor(TensorRole::OUTPUT, 0); */ - -/* spec.add_input(INPUT, input); */ -/* spec.add_input(KERNEL, kernel); */ - -/* if (this->use_bias) { */ -/* spec.add_input(BIAS, bias); */ -/* } */ - -/* spec.add_output(OUTPUT, output); */ - -/* return spec; */ -/* } */ - -/* TaskSpec Conv2D::get_backward_task_spec() const { */ -/* TaskSpec spec = { CONV2D_BWD_TASK_ID, Pass::BWD }; */ +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); -/* auto input = spec.add_tensor(TensorRole::INPUT, 0); */ -/* auto kernel = spec.add_tensor(TensorRole::PARAM, 0); */ -/* auto bias = spec.add_tensor(TensorRole::BIAS, 1); */ -/* auto output = spec.add_tensor(TensorRole::OUTPUT, 0); */ + auto input = acc.get_tensor(INPUT); + auto filter = acc.get_tensor(FILTER); + auto bias = acc.get_tensor(BIAS); + auto output = acc.get_tensor(OUTPUT); -/* spec.add_input(INPUT, input); */ -/* spec.add_output(INPUT_GRAD, input.grad); */ -/* spec.add_input(KERNEL, kernel); */ -/* spec.add_output(KERNEL_GRAD, kernel.grad); */ + return profile(forward_kernel, + profiling, + "[Conv2d] forward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + output.get_float_ptr(), + filter.get_float_ptr(), + bias.get_float_ptr()); +} -/* if (this->use_bias) { */ -/* spec.add_input(BIAS, bias); */ -/* spec.add_output(BIAS_GRAD, bias.grad); */ -/* } */ +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); +} -/* 
spec.add_input(OUTPUT, output); */ -/* spec.add_input(OUTPUT_GRAD, output.grad); */ +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); -/* return spec; */ -/* } */ + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto filter = acc.get_tensor(FILTER); -void Conv2D::forward(FFModel const &ff) { - this->execute_task(ff, CONV2D_FWD_TASK_ID, get_fwd_task_signature()); -} + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + auto filter_grad = acc.get_tensor_grad(FILTER); + auto bias_grad = acc.get_tensor_grad(BIAS); -void Conv2D::backward(FFModel const &ff) { - this->execute_task(ff, CONV2D_bWD_TASK_ID, get_bwd_task_signature()); + return profile(backward_kernel, + profiling, + "[Conv2d] backward_time = %.2lfms\n", + &per_device_state, + input.get_float_ptr(), + input_grad.get_float_ptr(), + output.get_float_ptr(), + output_grad.get_float_ptr(), + filter.get_float_ptr(), + filter_grad.get_float_ptr(), + bias_grad.get_float_ptr()); } -/* - regions[0](I): input - regions[1](O): output - regions[2](I): filter - regions[3](I): bias -*/ -void Conv2D::forward_task(Task const *task, +static void backward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - Conv2DPerDeviceState const *m = *((Conv2DPerDeviceState **)task->local_args); - TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); +} - auto input = acc.get_tensor(INPUT); - auto filter = acc.get_tensor(FILTER); - auto bias = acc.get_tensor(BIAS); - auto output = acc.get_tensor(OUTPUT); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + Conv2DAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + InputParallelTensorDesc const &filter_shape, + InputParallelTensorDesc const &bias_shape, + ProfilingSettings const &settings, + 
MachineView const &mv) { - // TensorAccessorR acc_input( - // regions[0], task->regions[0], FID_DATA, ctx, runtime); - // TensorAccessorW acc_output(regions[1], - // task->regions[1], - // FID_DATA, - // ctx, - // runtime, - // false - // /*readOutput*/); - // TensorAccessorR acc_kernel( - // regions[2], task->regions[2], FID_DATA, ctx, runtime); - // float const *acc_bias_ptr = NULL; - // if (m->use_bias) { - // TensorAccessorR acc_bias( - // regions[3], task->regions[3], FID_DATA, ctx, runtime); - // acc_bias_ptr = acc_bias.ptr; - // } + auto env = sim.new_environment(); - profile(forward_kernel, - m->profiling, - "[Conv2d] forward_time = %.2lfms\n", - m, - input.get_float_ptr(), - output.get_float_ptr(), - filter.get_float_ptr(), - bias.get_float_ptr()); -} + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape.shape); -/* - region(I): input - region(I/O): input_grad (if trainableInputs[0]) - region(I): output - region(I/O): output_grad - region(I): filter - region(I/O): filter_grad - region(I/O): bias_grad (if use_bias) -*/ -void Conv2D::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - // Conv2D* conv = (Conv2D*) task->args; - Conv2DPerDeviceState const *m = *((Conv2DPerDeviceState **)task->local_args); - assert(regions.size() == (5 + static_cast(m->trainableInputs[0]) + - static_cast(m->use_bias))); - assert(task->regions.size() == - (5 + static_cast(m->trainableInputs[0]) + - static_cast(m->use_bias))); - size_t rid = 0; - TensorAccessorR acc_input( - regions[rid], task->regions[rid], FID_DATA, ctx, runtime); - rid++; - float *acc_input_grad_ptr = NULL; - if (m->trainableInputs[0]) { - TensorAccessorW acc_input_grad( - regions[rid], - task->regions[rid], - FID_DATA, - ctx, - runtime, - true /*readOutput*/); - acc_input_grad_ptr = acc_input_grad.ptr; - rid++; - } - TensorAccessorR acc_output( - regions[rid], task->regions[rid], FID_DATA, ctx, runtime); - rid++; - TensorAccessorW 
acc_output_grad( - regions[rid], - task->regions[rid], - FID_DATA, - ctx, - runtime, - true /*readOutput*/); - rid++; - TensorAccessorR acc_kernel( - regions[rid], task->regions[rid], FID_DATA, ctx, runtime); - rid++; - TensorAccessorW acc_kernel_grad( - regions[rid], - task->regions[rid], - FID_DATA, - ctx, - runtime, - true /*readOutput*/); - rid++; - float *acc_bias_grad_ptr = NULL; - if (m->use_bias) { - TensorAccessorW acc_bias_grad( - regions[rid], - task->regions[rid], - FID_DATA, - ctx, - runtime, - true /*readOutput*/); - acc_bias_grad_ptr = static_cast(acc_bias_grad.ptr); - rid++; - } - assert(rid == regions.size()); + SimTaskBinding init_binding; + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(HANDLE, ff_handle()); - backward_kernel_wrapper(m, - acc_input.ptr, - acc_input_grad_ptr, - acc_output.ptr, - acc_output_grad.ptr, - acc_kernel.ptr, - acc_kernel_grad.ptr, - acc_bias_grad_ptr); -} + auto init_accessor = env.get_init_accessor(CONV2D_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); -bool Conv2D::estimate_sync_cost(Simulator *sim, - MachineView const &view, - CostMetrics &cost_metrics) const { - ParallelDim kernel_dims[MAX_TENSOR_DIM], bias_dims[MAX_TENSOR_DIM]; - int kernel_ndims, bias_ndims; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind(FILTER, filter_shape); + fwd_binding.bind(BIAS, bias_shape); - this->get_params().solve_dims(this->inputs[0]->get_shape(), - nullptr, - nullptr, - kernel_dims, - &kernel_ndims, - bias_dims, - &bias_ndims); + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - cost_metrics.sync_time = - sim->default_estimate_sync_cost(kernel_dims, kernel_ndims, view); + auto fwd_accessor = env.get_fwd_accessor(CONV2D_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = 
env.get_bwd_accessor(CONV2D_BWD_TASK_ID, bwd_binding); - if (this->use_bias) { - cost_metrics.sync_time += - sim->default_estimate_sync_cost(bias_dims, bias_ndims, view); - } + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); - return true; + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -tl::optional Conv2D::as_dot() const { - RecordFormatter rr; - RecordFormatter r; +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); - r << this->inputs[0]->get_shape().as_dot(); - r << "in_channels" << this->in_channels; - r << "out_channels" << this->out_channels; - r << "kernel_h" << this->kernel_h; - r << "kernel_w" << this->kernel_w; - r << "padding_h" << this->padding_h; - r << "padding_w" << this->padding_w; - r << "stride_h" << this->stride_h; - r << "stride_w" << this->stride_w; - r << this->outputs[0]->get_shape().as_dot(); - rr << r; + init.add_arg_slot(ATTRS); + init.add_unchecked_arg_slot(HANDLE); - return rr; + register_task(CONV2D_INIT_TASK_ID, "Conv2D Init", init, init_task); } -bool Conv2D::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - int input_w = sub_input.dims[0].size; - int input_h = sub_input.dims[1].size; - int input_c = sub_input.dims[2].size; - int input_n = sub_input.dims[3].size; - int output_w = sub_output.dims[0].size; - int output_h = sub_output.dims[1].size; - int output_c = sub_output.dims[2].size; - int output_n = sub_output.dims[3].size; - int pad_h = ((output_h - 1) * stride_h + kernel_h - input_h + 1) / 2; - int pad_w = ((output_w - 1) * stride_w + kernel_w - input_w + 1) / 2; - - Conv2DPerDeviceState *m = sim->conv2d_meta; - 
m->relu = activation == AC_MODE_RELU; - // require input_c is divisible by groups +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); - // allocate tensors in simulator - sim->free_all(); - float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + fwd.add_weight_slot(FILTER); + fwd.add_weight_slot(BIAS); - float *weight_ptr = (float *)sim->allocate( - (size_t)output_c * input_c * kernel_h * kernel_w / groups, DT_FLOAT); - assert(weight_ptr != NULL); - float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT); - assert(bias_ptr != NULL); - cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset); + register_task(CONV2D_FWD_TASK_ID, "Conv2D Fwd", fwd, forward_task); +} - init_kernel(m, - input_w, - input_h, - input_c, - input_n, - output_w, - output_h, - output_c, - output_n, - kernel_h, - kernel_w, - groups, - stride_h, - stride_w, - pad_h, - pad_w, - input_ptr, - output_ptr, - weight_ptr, - weight_ptr, // note we reuse weight_ptr for kernel_grad_ptr here - // to avoid allocating another tensor - &cost_metrics.forward_time, - &cost_metrics.backward_time); +template <> +void register_task() { + OpTaskSignature bwd = + infer_bwd_signature(get_op_signature(CONV2D_FWD_TASK_ID)); - log_measure.debug("[Measure Conv2D] name(%s) input(%d %d %d %d) weight(%d %d " - "%d %d) output(%d %d %d %d) stride(%d %d) padding(%d %d) " - "forward_time(%.4lf) backward_time(%.4lf)\n", - name, - input_n, - input_c, - input_h, - input_w, - output_c, - input_c / groups, - kernel_h, - kernel_w, - 
output_n, - output_c, - output_h, - output_w, - stride_h, - stride_w, - padding_h, - padding_w, - cost_metrics.forward_time, - cost_metrics.backward_time); - return true; + register_task(CONV2D_BWD_TASK_ID, "Conv2D Bwd", bwd, backward_task); } } // namespace FlexFlow diff --git a/lib/runtime/src/ops/conv_2d.h b/lib/runtime/src/ops/conv_2d.h index 382538b70a..777b491089 100644 --- a/lib/runtime/src/ops/conv_2d.h +++ b/lib/runtime/src/ops/conv_2d.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_CONV_2D_H #include "op-attrs/ops/conv_2d.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -134,3 +134,692 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, } // namespace FlexFlow #endif + +// Tensor FFModel::conv2d(Tensor const &input, +// int outChannels, +// int kernelH, +// int kernelW, +// int strideH, +// int strideW, +// int paddingH, +// int paddingW, +// ActiMode activation, +// int groups, +// bool use_bias, +// Layer const *shared_op, +// Initializer *kernel_initializer, +// Initializer *bias_initializer, +// char const *name) { +// assert(input->num_dims() == 4); /*NCHW*/ + +// Conv2DAttrs attrs = {outChannels, +// kernelH, +// kernelW, +// strideH, +// strideW, +// paddingH, +// paddingW, +// groups, +// activation, +// use_bias}; + +// TensorShape output_shape = get_output_shape(attrs, input->get_shape()); +// Tensor output = this->tensor_mgr.create(output_shape, CreateGrad::YES, +// conv); + +// std::vector weights; + +// TensorShape kernel_shape = get_kernel_shape(attrs, input->get_shape()); +// weights.push_back(this->tensor_mgr.create( +// kernel_shape, CreateGrad::YES, kernel_initializer, CHOSEN_SYNC_TYPE)); + +// if (use_bias) { +// TensorShape bias_shape = get_bias_shape(attrs, input->get_shape()); +// weights.push_back(this->tensor_mgr.create( +// bias_shape, CreateGrad::YES, bias_initializer, CHOSEN_SYNC_TYPE)); +// } + +// Layer *conv = +// 
this->layer_mgr.create(attrs, DT_FLOAT, name, {input}, weights, +// {output}); + +// //{ +// // int numdims = 4; +// // int dims[MAX_TENSOR_DIM]; +// // dims[3] = input->dims[3]; +// // dims[2] = outChannels; +// // dims[1] = 1 + (input->dims[1] + 2 * paddingH - kernelH) / strideH; +// // dims[0] = 1 + (input->dims[0] + 2 * paddingW - kernelW) / strideW; +// // conv->outputs[0] = create_tensor_legion_ordering( +// // numdims, dims, DT_FLOAT, conv, 0, true /*create_grad*/); +// //} +// //{ +// // int dims[4] = {kernelW, kernelH, input->dims[2], outChannels}; +// // conv->weights[0] = create_weight_legion_ordering(4, +// // dims, +// // DT_FLOAT, +// // conv, +// // true /*create_grad*/, +// // kernel_initializer, +// // CHOSEN_SYNC_TYPE); +// //} +// // if (use_bias) { +// // int dims[1] = {outChannels}; +// // conv->weights[1] = create_weight_legion_ordering(1, +// // dims, +// // DT_FLOAT, +// // conv, +// // true /*create_grad*/, +// // bias_initializer, +// // CHOSEN_SYNC_TYPE); +// //} +// conv->add_initializer("kernel", kernel_initializer); +// conv->add_initializer("bias", bias_initializer); +// /* layers.push_back(conv); */ +// return conv->outputs[0]; +// } + +// Op *Conv2D::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// return new Conv2D(model, +// get(layer->attrs), +// inputs, +// layer->name, +// false /*allocate_weights*/ +// ); +// } + +/* void Conv2DParams::mark_replica_dims( */ +/* ParallelTensorShape const &input, */ +/* ParallelDim output_dims[MAX_TENSOR_DIM], */ +/* ParallelDim kernel_dims[MAX_TENSOR_DIM], */ +/* ParallelDim bias_dims[MAX_TENSOR_DIM]) const { */ +/* if (output_dims != nullptr) { */ +/* output_dims[Conv2DOutput::REPLICA].is_replica_dim = true; */ +/* } */ +/* if (kernel_dims != nullptr) { */ +/* kernel_dims[Conv2DOutput::REPLICA].is_replica_dim = true; */ +/* } */ +/* if (bias_dims != nullptr) { */ +/* bias_dims[Conv2DBias::REPLICA_1].is_replica_dim = true; */ +/* 
bias_dims[Conv2DBias::REPLICA_2].is_replica_dim = true; */ +/* bias_dims[Conv2DBias::REPLICA_3].is_replica_dim = true; */ +/* bias_dims[Conv2DBias::REPLICA_4].is_replica_dim = true; */ +/* } */ +/* } */ + +/* int Conv2DParams::output_size(ParallelTensorShape const &input, */ +/* ParallelDim output_dims[MAX_TENSOR_DIM]) const + * { */ +/* int input_w = input.dims[Conv2DInput::WIDTH].size; */ +/* int input_h = input.dims[Conv2DInput::HEIGHT].size; */ + +/* output_dims[Conv2DOutput::SAMPLE].size = + * input.dims[Conv2DInput::SAMPLE].size; */ +/* output_dims[Conv2DOutput::CHANNEL].size = out_channels; */ +/* output_dims[Conv2DOutput::HEIGHT].size = */ +/* 1 + (input_h + 2 * padding_h - kernel_h) / stride_h; */ +/* output_dims[Conv2DOutput::WIDTH].size = */ +/* 1 + (input_w + 2 * padding_w - kernel_w) / stride_w; */ + +/* return input.num_dims; */ +/* }; */ + +/* int Conv2DParams::kernel_size(ParallelTensorShape const &input, */ +/* ParallelDim kernel_dims[MAX_TENSOR_DIM]) const + * { */ +/* kernel_dims[Conv2DKernel::CHANNEL_OUT].size = this->out_channels; */ +/* kernel_dims[Conv2DKernel::CHANNEL_IN].size = */ +/* input.dims[Conv2DInput::CHANNEL].size / this->groups; */ +/* kernel_dims[Conv2DKernel::HEIGHT].size = */ +/* this->kernel_h * input.dims[Conv2DInput::HEIGHT].degree; */ +/* kernel_dims[Conv2DKernel::WIDTH].size = */ +/* this->kernel_w * input.dims[Conv2DInput::WIDTH].degree; */ + +/* return Conv2DKernel::NUMDIM; */ +/* } */ + +/* int Conv2DParams::bias_size(ParallelTensorShape const &input, */ +/* ParallelDim bias_dims[MAX_TENSOR_DIM]) const { */ +/* bias_dims[Conv2DBias::CHANNEL].size = this->out_channels; */ + +/* return Conv2DBias::NUMDIM; */ +/* }; */ + +/* void Conv2DParams::solve_dims(ParallelTensorShape const &input, */ +/* ParallelDim output_dims[MAX_TENSOR_DIM], */ +/* int *output_ndims, */ +/* ParallelDim kernel_dims[MAX_TENSOR_DIM], */ +/* int *kernel_ndims, */ +/* ParallelDim bias_dims[MAX_TENSOR_DIM], */ +/* int *bias_ndims) const { */ +/* 
assert((output_dims == nullptr) == (output_ndims == nullptr)); */ +/* assert((kernel_dims == nullptr) == (kernel_ndims == nullptr)); */ +/* assert((bias_dims == nullptr) == (bias_ndims == nullptr)); */ + +/* std::vector mapping; */ +/* Conv2D::construct_mappings(mapping, this->use_bias); */ + +/* this->mark_replica_dims(input, output_dims, kernel_dims, bias_dims); */ + +/* std::vector output_dim_sets; */ +/* if (output_dims != nullptr) { */ +/* output_dim_sets.push_back(output_dims); */ +/* } */ + +/* std::vector weight_dim_sets; */ +/* if (kernel_dims != nullptr) { */ +/* weight_dim_sets.push_back(kernel_dims); */ +/* } */ +/* if (bias_dims != nullptr && this->use_bias) { */ +/* weight_dim_sets.push_back(bias_dims); */ +/* } */ + +/* solve_parallel_dim_mappings( */ +/* mapping, {input.dims}, weight_dim_sets, output_dim_sets); */ + +/* if (output_dims != nullptr) { */ +/* *output_ndims = this->output_size(input, output_dims); */ +/* } */ +/* if (kernel_dims != nullptr) { */ +/* *kernel_ndims = this->kernel_size(input, kernel_dims); */ +/* } */ +/* if (bias_dims != nullptr && this->use_bias) { */ +/* *bias_ndims = this->bias_size(input, bias_dims); */ +/* } */ +/* } */ + +/*static*/ +/* void Conv2D::construct_mappings(std::vector &out, + */ +/* bool use_bias) { */ +/* Conv2D::construct_output_mappings(out); */ +/* Conv2D::construct_weight_mappings(out, use_bias); */ +/* } */ + +/*static*/ +/* void Conv2D::construct_output_mappings( */ +/* std::vector &out) { */ +/* Op::construct_output_parallel_dims( */ +/* out, */ +/* {{Conv2DInput::CHANNEL, */ +/* MappingOperation::REPLICATE, */ +/* Conv2DOutput::REPLICA}, */ +/* {Conv2DInput::SAMPLE, MappingOperation::PARTITION, + * Conv2DOutput::SAMPLE}, */ +/* {Conv2DInput::REPLICA, */ +/* MappingOperation::PARTITION, */ +/* Conv2DOutput::CHANNEL}, */ +/* {Conv2DInput::HEIGHT, MappingOperation::PARTITION, + * Conv2DOutput::HEIGHT}, */ +/* {Conv2DInput::WIDTH, MappingOperation::PARTITION, + * Conv2DOutput::WIDTH}}); */ +/* } */ 
+ +/*static*/ +/* void Conv2D::construct_weight_mappings( */ +/* std::vector &out, bool use_bias) { */ +/* Op::construct_weight_parallel_dims( */ +/* out, */ +/* { */ +/* {Conv2DInput::REPLICA, */ +/* MappingOperation::PARTITION, */ +/* Conv2DKernel::CHANNEL_OUT}, */ +/* {Conv2DInput::SAMPLE, */ +/* MappingOperation::REPLICATE, */ +/* Conv2DKernel::REPLICA}, */ +/* {Conv2DInput::CHANNEL, */ +/* MappingOperation::PARTITION, */ +/* Conv2DKernel::CHANNEL_IN}, */ +/* {Conv2DInput::HEIGHT, */ +/* MappingOperation::REPLICATE, */ +/* Conv2DKernel::HEIGHT}, // Kernel::{HEIGHT, WEIGHT} would both work + */ +/* // here */ +/* {Conv2DInput::WIDTH, */ +/* MappingOperation::REPLICATE, */ +/* Conv2DKernel::WIDTH}, // same as above */ +/* }, */ +/* Conv2DInput::INDEX, */ +/* Conv2DKernel::INDEX); */ + +/* if (use_bias) { */ +/* Op::construct_weight_parallel_dims( */ +/* out, */ +/* {{Conv2DInput::REPLICA, Conv2DBias::REPLICA_1}, */ +/* {Conv2DInput::SAMPLE, Conv2DBias::REPLICA_2}, */ +/* {Conv2DInput::CHANNEL, Conv2DBias::CHANNEL}, */ +/* {Conv2DInput::HEIGHT, Conv2DBias::REPLICA_3}, */ +/* {Conv2DInput::WIDTH, Conv2DBias::REPLICA_4}}, */ +/* Conv2DInput::INDEX, */ +/* Conv2DBias::INDEX); */ +/* } */ +/* } */ + +// Conv2D::Conv2D(FFModel &model, +// Conv2D const &other, +// const ParallelTensor input, +// bool allocate_weights) +// : Conv2D(model, +// other.layer_guid, +// input, +// other.out_channels, +// other.kernel_h, +// other.kernel_w, +// other.stride_h, +// other.stride_w, +// other.padding_h, +// other.padding_w, +// other.activation, +// other.groups, +// other.use_bias, +// allocate_weights, +// other.name) {} + +// Conv2D::Conv2D(FFModel &model, +// Conv2DAttrs const &attrs, +// std::vector const &inputs, +// char const *name, +// bool allocate_weights) +// : Conv2D(model, +// params.layer_guid, +// input, +// params.out_channels, +// params.kernel_h, +// params.kernel_w, +// params.stride_h, +// params.stride_w, +// params.padding_h, +// params.padding_w, +// 
params.activation, +// params.groups, +// params.use_bias, +// allocate_weights, +// name) {} + +/* bool Conv2DParams::is_valid(ParallelTensorShape const &input) const { */ +/* ParallelTensorShape output_shape, kernel_shape, bias_shape; */ +/* this->solve_dims(input, */ +/* output_shape.dims, */ +/* &output_shape.num_dims, */ +/* kernel_shape.dims, */ +/* &kernel_shape.num_dims, */ +/* bias_shape.dims, */ +/* &bias_shape.num_dims); */ +/* bool is_valid = true; */ +/* is_valid &= input.is_valid(); */ +/* is_valid &= output_shape.is_valid(); */ +/* is_valid &= kernel_shape.is_valid(); */ +/* if (use_bias) { */ +/* is_valid &= bias_shape.is_valid(); */ +/* } */ + +/* // TODO FIXME: Currently disable parallelizing the height and width + * dimension */ +/* if (input.dims[0].degree > 1 || input.dims[1].degree > 1) { */ +/* return false; */ +/* } */ + +/* return is_valid; */ +/* } */ + +// Conv2D::Conv2D(FFModel &model, +// LayerID const &_layer_guid, +// const ParallelTensor input, +// int outChannels, +// int kernelH, +// int kernelW, +// int strideH, +// int strideW, +// int paddingH, +// int paddingW, +// ActiMode activation, +// int groups, +// bool use_bias, +// bool allocate_weights, +// char const *name) +// : Op(model, +// OP_CONV2D, +// DT_FLOAT, +// name, +// 1 /*inputs*/, +// use_bias ? 
2 : 1 /*weights*/, +// allocate_weights, +// 1 /*outputs*/, +// input), +// in_channels(input->dims[Conv2DInput::CHANNEL].size / +// input->dims[Conv2DInput::CHANNEL].degree), +// out_channels(outChannels), kernel_h(kernelH), kernel_w(kernelW), +// stride_h(strideH), stride_w(strideW), padding_h(paddingH), +// padding_w(paddingW), activation(activation), groups(groups), +// use_bias(use_bias) { +// // overwrite layer_guid +// layer_guid = _layer_guid; +// assert(input->num_dims == Conv2DInput::NUMDIM); +// assert(this->stride_h > 0); +// assert(this->stride_w > 0); + +// ParallelDim output_dims[MAX_TENSOR_DIM], kernel_dims[MAX_TENSOR_DIM], +// bias_dims[MAX_TENSOR_DIM]; +// int output_ndims, kernel_ndims, bias_ndims; + +// this->construct_mappings(*this->parallel_dims_mapping, this->use_bias); +// this->get_params().solve_dims(this->inputs[0]->get_shape(), +// output_dims, +// &output_ndims, +// kernel_dims, +// &kernel_ndims, +// bias_dims, +// &bias_ndims); + +// if (allocate_weights) { +// Initializer *kernel_initializer = new GlorotUniform(std::rand() +// /*seed*/); + +// weights[Conv2DKernel::INDEX] = +// model.create_parallel_weight_legion_ordering(kernel_ndims, +// kernel_dims, +// DT_FLOAT, +// NULL /*owner_op*/, +// true /*create_grad*/, +// kernel_initializer, +// CHOSEN_SYNC_TYPE); + +// if (use_bias) { +// Initializer *bias_initializer = new ZeroInitializer(); + +// weights[Conv2DBias::INDEX] = +// model.create_parallel_weight_legion_ordering(bias_ndims, +// bias_dims, +// DT_FLOAT, +// NULL /*owner_op*/, +// true /*create_grad*/, +// bias_initializer, +// CHOSEN_SYNC_TYPE); +// } +// } + +// outputs[0] = model.create_parallel_tensor_legion_ordering( +// output_ndims, output_dims, DT_FLOAT, this); + +// assert(check_output_input_weight_parallel_dims(allocate_weights)); +// } + +// tl::optional Conv2D::as_dot() const { +// RecordFormatter rr; +// RecordFormatter r; + +// r << this->inputs[0]->get_shape().as_dot(); +// r << "in_channels" << 
this->in_channels; +// r << "out_channels" << this->out_channels; +// r << "kernel_h" << this->kernel_h; +// r << "kernel_w" << this->kernel_w; +// r << "padding_h" << this->padding_h; +// r << "padding_w" << this->padding_w; +// r << "stride_h" << this->stride_h; +// r << "stride_w" << this->stride_w; +// r << this->outputs[0]->get_shape().as_dot(); +// rr << r; + +// return rr; +// } + +// bool Conv2D::estimate_sync_cost(Simulator *sim, +// MachineView const &view, +// CostMetrics &cost_metrics) const { +// ParallelDim kernel_dims[MAX_TENSOR_DIM], bias_dims[MAX_TENSOR_DIM]; +// int kernel_ndims, bias_ndims; + +// this->get_params().solve_dims(this->inputs[0]->get_shape(), +// nullptr, +// nullptr, +// kernel_dims, +// &kernel_ndims, +// bias_dims, +// &bias_ndims); + +// cost_metrics.sync_time = +// sim->default_estimate_sync_cost(kernel_dims, kernel_ndims, view); + +// if (this->use_bias) { +// cost_metrics.sync_time += +// sim->default_estimate_sync_cost(bias_dims, bias_ndims, view); +// } + +// return true; +// } + +/* + region(I): input + region(I/O): input_grad (if trainableInputs[0]) + region(I): output + region(I/O): output_grad + region(I): filter + region(I/O): filter_grad + region(I/O): bias_grad (if use_bias) +*/ + +/* + regions[0](I): input + regions[1](O): output + regions[2](I): filter + regions[3](I): bias +*/ + +// void Conv2D::forward(FFModel const &ff) { +// this->execute_task(ff, CONV2D_FWD_TASK_ID, get_fwd_task_signature()); +// } + +// void Conv2D::backward(FFModel const &ff) { +// this->execute_task(ff, CONV2D_bWD_TASK_ID, get_bwd_task_signature()); +// } + +// TaskSpec Conv2D::get_tasks_spec() const { +// OpTasksSpec spec { +// CONV2D_INIT_TASK_ID, +// CONV2D_FWD_TASK_ID, +// CONV2D_BWD_TASK_ID +// }; +// auto &fwd = spec.get_fwd(); + +// fwd.add_input_slot(INPUT); +// fwd.add_param_slot(KERNEL); +// fwd.add_output_slot(OUTPUT); + +// auto input = spec.input_tensor(0); +// auto kernel = spec.param_tensor(0); +// auto bias = 
spec.param_tensor(1); +// auto output = spec.output_tensor(0); + +// fwd[INPUT] = input; +// fwd[KERNEL] = kernel; +// if (this->use_bias) { +// fwd[BIAS] = bias; +// } +// fwd[OUTPUT] = output; + +// return spec; +// } + +/* TaskSpec Conv2D::get_forward_task_spec() const { */ +/* TaskSpec spec = { CONV2D_FWD_TASK_ID, Pass::FWD }; */ + +/* auto input = spec.add_tensor(TensorRole::INPUT, 0); */ +/* auto kernel = spec.add_tensor(TensorRole::PARAM, 0); */ +/* auto bias = spec.add_tensor(TensorRole::BIAS, 1); */ +/* auto output = spec.add_tensor(TensorRole::OUTPUT, 0); */ + +/* spec.add_input(INPUT, input); */ +/* spec.add_input(KERNEL, kernel); */ + +/* if (this->use_bias) { */ +/* spec.add_input(BIAS, bias); */ +/* } */ + +/* spec.add_output(OUTPUT, output); */ + +/* return spec; */ +/* } */ + +/* TaskSpec Conv2D::get_backward_task_spec() const { */ +/* TaskSpec spec = { CONV2D_BWD_TASK_ID, Pass::BWD }; */ + +/* auto input = spec.add_tensor(TensorRole::INPUT, 0); */ +/* auto kernel = spec.add_tensor(TensorRole::PARAM, 0); */ +/* auto bias = spec.add_tensor(TensorRole::BIAS, 1); */ +/* auto output = spec.add_tensor(TensorRole::OUTPUT, 0); */ + +/* spec.add_input(INPUT, input); */ +/* spec.add_output(INPUT_GRAD, input.grad); */ +/* spec.add_input(KERNEL, kernel); */ +/* spec.add_output(KERNEL_GRAD, kernel.grad); */ + +/* if (this->use_bias) { */ +/* spec.add_input(BIAS, bias); */ +/* spec.add_output(BIAS_GRAD, bias.grad); */ +/* } */ + +/* spec.add_input(OUTPUT, output); */ +/* spec.add_input(OUTPUT_GRAD, output.grad); */ + +/* return spec; */ +/* } */ + +/* + regions[0]: input + regions[1]: output + regions[2](I): filter + regions[3](I): bias + regions[4](O): filter_grad + regions[5](O): input_grad +*/ + +// void Conv2D::init(FFModel const &ff) { +// this->execute_task(ff, CONV2D_INIT_TASK_ID, get_init_task_signature()); +// assert(check_output_input_weight_same_parallel_is()); +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = 
ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_init(ff, argmap); +// IndexLauncher launcher(CONV2D_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(Conv2D)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// weights[0]->region)); +// launcher.add_field(2, FID_DATA); +// // launcher.add_region_requirement( +// // RegionRequirement(weights[1]->part, 0/*projection id*/, +// // READ_ONLY, EXCLUSIVE, weights[1]->region)); +// // launcher.add_field(3, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// weights[0]->region_grad)); +// launcher.add_field(3, FID_DATA); +// // launcher.add_region_requirement( +// // RegionRequirement(inputs[0]->part_grad, 0/*projection id*/, +// // WRITE_ONLY, EXCLUSIVE, inputs[0]->region_grad)); +// // launcher.add_field(4, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// set_opmeta_from_futuremap(ff, fm); +// } + +// printf("init conv (input): n(%d) c(%d) h(%d) w(%d)\n", +// input_n, +// input_c, +// input_h, +// input_w); +// printf("init conv (output): n(%d) c(%d) h(%d) w(%d)\n", +// output_n, +// output_c, +// output_h, +// output_w); + +// printf("convDim: padding(%d %d) stride(%d %d)\n", conv->padding_h, +// conv->padding_w, conv->stride_h, conv->stride_w); +// int pad_h = 
+// ((output_h - 1) * attrs.stride_h + attrs.kernel_h - input_h + 1) / 2; +// int pad_w = +// ((output_w - 1) * attrs.stride_w + attrs.kernel_w - input_w + 1) / 2; +// if (pad_h != attrs.padding_h) { +// printf("Warning: changing conv_padding_h to satisfy output_h size\n"); +// } +// if (pad_w != attrs.padding_w) { +// printf("Warning: changing conv_padding_w to satisfy output_w size\n"); +// } + +// size_t rid = 0; +// TensorAccessorR acc_input( +// regions[rid], task->regions[rid], FID_DATA, ctx, runtime); +// rid++; +// float *acc_input_grad_ptr = NULL; +// if (m->trainableInputs[0]) { +// TensorAccessorW acc_input_grad( +// regions[rid], +// task->regions[rid], +// FID_DATA, +// ctx, +// runtime, +// true /*readOutput*/); +// acc_input_grad_ptr = acc_input_grad.ptr; +// rid++; +// } +// TensorAccessorR acc_output( +// regions[rid], task->regions[rid], FID_DATA, ctx, runtime); +// rid++; +// TensorAccessorW acc_output_grad( +// regions[rid], +// task->regions[rid], +// FID_DATA, +// ctx, +// runtime, +// true /*readOutput*/); +// rid++; +// TensorAccessorR acc_kernel( +// regions[rid], task->regions[rid], FID_DATA, ctx, runtime); +// rid++; +// TensorAccessorW acc_kernel_grad( +// regions[rid], +// task->regions[rid], +// FID_DATA, +// ctx, +// runtime, +// true /*readOutput*/); +// rid++; +// float *acc_bias_grad_ptr = NULL; +// if (m->use_bias) { +// TensorAccessorW acc_bias_grad( +// regions[rid], +// task->regions[rid], +// FID_DATA, +// ctx, +// runtime, +// true /*readOutput*/); +// acc_bias_grad_ptr = static_cast(acc_bias_grad.ptr); +// rid++; +// } diff --git a/lib/runtime/src/serialization.h b/lib/runtime/src/serialization.h index adf838201a..5c1194c7d6 100644 --- a/lib/runtime/src/serialization.h +++ b/lib/runtime/src/serialization.h @@ -7,9 +7,10 @@ #include "legion/legion_utilities.h" #include "op-attrs/dim_ordered.h" #include "utils/optional.h" +#include "utils/required.h" +#include "utils/type_traits.h" #include "utils/variant.h" #include 
"utils/visitable.h" -#include namespace FlexFlow { @@ -77,6 +78,13 @@ struct is_trivially_serializable< typename std::enable_if::value>::type> : std::true_type {}; +template +struct is_trivially_serializable>> + : is_trivially_serializable> {}; + +template +struct is_trivially_serializable> : is_trivially_serializable {}; + template <> struct is_trivially_serializable : std::true_type {}; template <> diff --git a/lib/runtime/src/task_spec/op_task_invocation.h b/lib/runtime/src/task_spec/op_task_invocation.h index 07f5bf12ae..56e709734e 100644 --- a/lib/runtime/src/task_spec/op_task_invocation.h +++ b/lib/runtime/src/task_spec/op_task_invocation.h @@ -6,6 +6,7 @@ #include "legion.h" #include "op_arg_ref.h" #include "op_task_signature.h" +#include "op_tensor_spec.h" #include "runtime/config.h" #include "runtime/profiling.h" #include "serialization.h" @@ -14,6 +15,7 @@ #include "utils/bidict.h" #include "utils/optional.h" #include "utils/stack_map.h" +#include "variadic_tensor_ref.h" #include #include #include @@ -22,16 +24,6 @@ namespace FlexFlow { enum class IsTrainable { YES, NO }; -struct OpTensorSpec { - TensorRole role; - req idx; -}; -FF_VISITABLE_STRUCT(OpTensorSpec, role, idx); - -OpTensorSpec input_tensor(int); -OpTensorSpec output_tensor(int); -OpTensorSpec weight_tensor(int); - using OpArgSpec = variant + void bind(slot_id name, VariadicTensorRef const &t) { + NOT_IMPLEMENTED(); + } + template void bind_device_specific_arg(slot_id name, T const &t) { NOT_IMPLEMENTED(); diff --git a/lib/runtime/src/task_spec/op_tensor_spec.h b/lib/runtime/src/task_spec/op_tensor_spec.h new file mode 100644 index 0000000000..d859bb3072 --- /dev/null +++ b/lib/runtime/src/task_spec/op_tensor_spec.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_TENSOR_SPEC_REF_H +#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_TENSOR_SPEC_REF_H + +#include "op_task_signature.h" + +namespace FlexFlow { + +struct OpTensorSpec { + TensorRole role; + req idx; +}; 
+FF_VISITABLE_STRUCT(OpTensorSpec, role, idx); + +OpTensorSpec input_tensor(int); +OpTensorSpec output_tensor(int); +OpTensorSpec weight_tensor(int); + +} // namespace FlexFlow + +#endif diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.h b/lib/runtime/src/task_spec/runtime_arg_ref.h index 6b4345091a..033c2bcfbc 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.h +++ b/lib/runtime/src/task_spec/runtime_arg_ref.h @@ -3,6 +3,7 @@ #include "arg_ref.h" #include "device_specific.h" +#include "runtime/config.h" namespace FlexFlow { @@ -15,6 +16,7 @@ using RuntimeArgRefSpec = ArgRefSpec; RuntimeArgRef profiling_settings(); RuntimeArgRef> ff_handle(); +RuntimeArgRef iteration_config(); } // namespace FlexFlow diff --git a/lib/runtime/src/task_spec/variadic_tensor_ref.h b/lib/runtime/src/task_spec/variadic_tensor_ref.h new file mode 100644 index 0000000000..ddd9bd5069 --- /dev/null +++ b/lib/runtime/src/task_spec/variadic_tensor_ref.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H +#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H + +#include "arg_ref.h" +#include "op_tensor_spec.h" + +namespace FlexFlow { + +enum class VariadicTensorRefType { INPUT_TENSORS, NUM_INPUTS }; + +template +using VariadicTensorRef = ArgRef; + +VariadicTensorRef get_input_tensors() { + return {VariadicTensorRefType::INPUT_TENSORS}; +} + +VariadicTensorRef get_number_inputs() { + return {VariadicTensorRefType::NUM_INPUTS}; +} + +} // namespace FlexFlow + +#endif