diff --git a/deps/fmt b/deps/fmt index f5e54359df..a33701196a 160000 --- a/deps/fmt +++ b/deps/fmt @@ -1 +1 @@ -Subproject commit f5e54359df4c26b6230fc61d38aa294581393084 +Subproject commit a33701196adfad74917046096bf5a2aa0ab0bb50 diff --git a/lib/kernels/include/kernels/element_unary_kernels.h b/lib/kernels/include/kernels/element_unary_kernels.h index 428c0ed897..5fac73132d 100644 --- a/lib/kernels/include/kernels/element_unary_kernels.h +++ b/lib/kernels/include/kernels/element_unary_kernels.h @@ -3,42 +3,50 @@ #include "kernels/accessor.h" #include "kernels/device.h" -#include "legion.h" +#include "kernels/ff_handle.h" +#include "op-attrs/ops/element_unary.h" #include namespace FlexFlow { -class ElementUnaryPerDeviceState : public PerDeviceOpState { -public: - ElementUnaryPerDeviceState(FFHandler handle); +struct ElementUnaryPerDeviceState { + PerDeviceFFHandle handle; ffTensorDescriptor_t inputTensor, outputTensor; ffActivationDescriptor_t actiDesc; OperatorType op_type; DataType data_type; - bool inplace; float scalar; - char op_name[MAX_OPNAME]; }; +FF_VISITABLE_STRUCT_NO_EQ(ElementUnaryPerDeviceState, + handle, + inputTensor, + outputTensor, + actiDesc, + op_type, + data_type, + scalar); + namespace Kernels { namespace ElementUnary { -void init_kernel(ElementUnaryPerDeviceState *m, - Legion::Domain const &input_domain, - Legion::Domain const &output_domain); +ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + ArrayShape const &input_shape, + ArrayShape const &output_shape, + DataType data_type); void forward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output); void backward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &input_grad, - GenericTensorAccessorW const &output, - GenericTensorAccessorW const &output_grad); + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad); } // namespace ElementUnary } // namespace Kernels diff --git a/lib/kernels/include/kernels/embedding_kernels.h b/lib/kernels/include/kernels/embedding_kernels.h index 9d70fd9a79..34b892c17e 100644 --- a/lib/kernels/include/kernels/embedding_kernels.h +++ b/lib/kernels/include/kernels/embedding_kernels.h @@ -5,29 +5,25 @@ #include "kernels/device.h" namespace FlexFlow { - -class EmbeddingPerDeviceState : public PerDeviceOpState { -public: - EmbeddingPerDeviceState(FFHandler handle); - DataType input_data_type, output_data_type; - AggrMode aggr; -}; - namespace Kernels { namespace Embedding { void forward_kernel(ffStream_t stream, - EmbeddingPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorR const &weight, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size); void backward_kernel(ffStream_t stream, - EmbeddingPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorR const &output, GenericTensorAccessorW const &weight_grad, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size); diff --git a/lib/kernels/include/kernels/layer_norm_kernels.h b/lib/kernels/include/kernels/layer_norm_kernels.h index a49e1b3483..fb49854653 100644 --- 
a/lib/kernels/include/kernels/layer_norm_kernels.h +++ b/lib/kernels/include/kernels/layer_norm_kernels.h @@ -5,42 +5,39 @@ namespace FlexFlow { -class LayerNormPerDeviceState : public PerDeviceOpState { -public: - LayerNormPerDeviceState(FFHandler handle, - bool elementwise_affine_, - int64_t effective_batch_size_, - int64_t effective_num_elements_, - bool profiling_, - float eps_); - -public: - bool elementwise_affine; - int64_t effective_batch_size, effective_num_elements; - float eps; +struct LayerNormPerDeviceState { float *mean, *rstd, *ds, *db, *scale, *bias; - char op_name[MAX_OPNAME]; - DataType data_type; }; namespace Kernels { namespace LayerNorm { +LayerNormPerDeviceState init_kernel(PerDeviceFFHandle handle, + int64_t batch_size); + void forward_kernel(ffStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorW const &gamma, - GenericTensorAccessorW const &beta); + GenericTensorAccessorW const &beta, + DataType data_type, + int64_t batch_size, + int64_t num_elements, + float eps); void backward_kernel(ffStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &gamma, GenericTensorAccessorW const &gamma_grad, - GenericTensorAccessorW const &beta_grad); + GenericTensorAccessorW const &beta_grad, + DataType data_type, + int64_t batch_size, + int64_t num_elements, + float eps); } // namespace LayerNorm } // namespace Kernels diff --git a/lib/kernels/src/cuda/element_unary_kernels.cu b/lib/kernels/src/cuda/element_unary_kernels.cu index 1251f75603..568e1f319c 100644 --- a/lib/kernels/src/cuda/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/element_unary_kernels.cu @@ -18,18 +18,6 @@ #include "kernels/element_unary_kernels.h" namespace FlexFlow { - -// declare Legion names -using Legion::coord_t; -using Legion::Domain; - -ElementUnaryPerDeviceState::ElementUnaryPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) { - checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); - checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); - checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); -} - namespace Kernels { namespace ElementUnary { @@ -45,13 +33,23 @@ static bool use_cudnn(OperatorType op_type) { } } -void init_kernel(ElementUnaryPerDeviceState *m, - Domain const &input_domain, - Domain const &output_domain) { +ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + ArrayShape const &input_shape, + ArrayShape const &output_shape, + OperatorType op_type, + DataType data_type) { + + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffActivationDescriptor_t actiDesc; - if (use_cudnn(m->op_type)) { + checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); + checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); + checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); + + if (use_cudnn(op_type)) { cudnnActivationMode_t mode; - switch (m->op_type) { + switch (op_type) { case OP_SIGMOID: mode = CUDNN_ACTIVATION_SIGMOID; break; @@ -67,32 +65,37 @@ void init_kernel(ElementUnaryPerDeviceState *m, default: assert(false); } - checkCUDNN(cudnnSetActivationDescriptor( - m->actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0)); checkCUDNN( - cudnnSetTensorDescriptorFromDomain(m->inputTensor, input_domain)); - // input_domain == output_domain + 
cudnnSetActivationDescriptor(actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0)); checkCUDNN( - cudnnSetTensorDescriptorFromDomain(m->outputTensor, output_domain)); + cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape)); + // input_shape == output_shape + checkCUDNN( + cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape)); } + + ElementUnaryPerDeviceState per_device_state = { + handle, inputTensor, outputTensor, actiDesc, op_type, data_type, scalar}; + + return per_device_state; } template struct ForwardKernel { void operator()(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) const { - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + checkCUDNN(cudnnSetStream(m.handle.dnn, stream)); + if (use_cudnn(m.op_type)) { float alpha = 1.0f, beta = 0.0f; - checkCUDNN(cudnnActivationForward(m->handle.dnn, - m->actiDesc, + checkCUDNN(cudnnActivationForward(m.handle.dnn, + m.actiDesc, &alpha, - m->inputTensor, + m.inputTensor, input.get(), &beta, - m->outputTensor, + m.outputTensor, output.get())); } else { size_t num_elements = input.shape.num_elements(); @@ -100,8 +103,8 @@ struct ForwardKernel { CUDA_NUM_THREADS, 0, stream>>>(num_elements, - (T)m->scalar, - m->op_type, + (T)m.scalar, + m.op_type, input.get(), output.get()); } @@ -111,34 +114,34 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &m, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &input_grad, - GenericTensorAccessorW const &output, - GenericTensorAccessorW const &output_grad) { - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad) { + checkCUDNN(cudnnSetStream(m.handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + if (use_cudnn(m.op_type)) { float alpha = 1.0f; - checkCUDNN(cudnnActivationBackward(m->handle.dnn, - m->actiDesc, + checkCUDNN(cudnnActivationBackward(m.handle.dnn, + m.actiDesc, &alpha, - m->outputTensor, + m.outputTensor, output.get(), - m->outputTensor, + m.outputTensor, output_grad.get()), - m->inputTensor, + m.inputTensor, input.get(), &alpha, - m->inputTensor, + m.inputTensor, input_grad.get())); } else { size_t num_elements = input.shape.num_elements(); elewise_unary_backward_kernel <<>>( num_elements, - m->scalar, - m->op_type, + m.scalar, + m.op_type, output.get(), output_grad.get(), input.get(), @@ -148,21 +151,19 @@ struct BackwardKernel { } void forward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - { - DataTypeDispatch1{}(m->data_type, stream, m, input, output); - } + { DataTypeDispatch1{}(m.data_type, stream, m, input, output); } void backward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, GenericTensorAccessorR const &input_grad, GenericTensorAccessorW const &output, GenericTensorAccessorW const &output_grad) DataTypeDispatch1{}( - m->data_type, stream, m, input, input_grad, output, output_grad); + m.data_type, stream, m, input, input_grad, output, output_grad); } template diff --git a/lib/kernels/src/cuda/embedding_kernels.cu 
b/lib/kernels/src/cuda/embedding_kernels.cu index b97d74d010..9d3cca66a0 100644 --- a/lib/kernels/src/cuda/embedding_kernels.cu +++ b/lib/kernels/src/cuda/embedding_kernels.cu @@ -24,7 +24,7 @@ namespace Embedding { template struct ForwardKernel { void operator()(cudaStream_t stream, - EmbeddingPerDeviceState const *m, + AggrMode aggr, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorR const &weight, @@ -35,8 +35,8 @@ struct ForwardKernel { assert(weight.data_type == DT_HALF || weight.data_type == DT_FLOAT || weight.data_type == DT_DOUBLE); - if (m->aggr == AGGR_MODE_NONE) { - embed_forward_no_aggr<<<<>>(input.get(), @@ -45,8 +45,8 @@ struct ForwardKernel { out_dim, batch_size); } else { - assert(m->aggr == AGGR_MODE_AVG || m->aggr == AGGR_MODE_SUM); - embed_forward_with_aggr<<<<>>(input.get(), @@ -55,7 +55,7 @@ struct ForwardKernel { out_dim, in_dim, batch_size, - m->aggr); + aggr); } } } @@ -63,7 +63,7 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(cudaStream_t stream, - EmbeddingPerDeviceState const *m, + AggrMode aggr, GenericTensorAccessorR const &input, GenericTensorAccessorR const &output, GenericTensorAccessorW const &weight_grad, @@ -73,8 +73,8 @@ struct BackwardKernel { assert(input.data_type == DT_INT32 || input.data_type == DT_INT64); assert(output.data_type == DT_HALF || output.data_type == DT_FLOAT, || output.data_type == DT_DOUBLE); - if (m->aggr == AGGR_MODE_NONE) { - embed_backward_no_aggr<<<<>>(input.get(), @@ -83,7 +83,7 @@ struct BackwardKernel { out_dim, batch_size); } else { - embed_backward_with_aggr<<<<>>(input.get(), @@ -92,23 +92,25 @@ struct BackwardKernel { out_dim, in_dim, batch_size, - m->aggr); + aggr); } } } -void forward_kernel(cudaStream_t stream, - EmbeddingPerDeviceState const *m, +void forward_kernel(ffStream_t stream, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorR const &weight, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size) { - DataTypeDispatch2{}(m->input_data_type, - m->output_data_type, + DataTypeDispatch2{}(input_data_type, + output_data_type, stream, - m, + aggr, input, output, weight, @@ -118,17 +120,19 @@ void forward_kernel(cudaStream_t stream, } void backward_kernel(cudaStream_t stream, - EmbeddingPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorR const &output, GenericTensorAccessorW const &weight_grad, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size) { - DataTypeDispatch2{}(m->input_data_type, - m->output_data_type, + DataTypeDispatch2{}(input_data_type, + output_data_type, stream, - m, + aggr, input, output, weight, diff --git a/lib/kernels/src/cuda/layer_norm_kernels.cu b/lib/kernels/src/cuda/layer_norm_kernels.cu index 65d33bec5e..eb9e291750 100644 --- a/lib/kernels/src/cuda/layer_norm_kernels.cu +++ b/lib/kernels/src/cuda/layer_norm_kernels.cu @@ -24,86 +24,73 @@ constexpr int kCUDABlockReduceNumThreads = 512; constexpr int kCUDANumThreads = 256; constexpr int kColwiseReduceTileSize = 32; -LayerNormPerDeviceState::LayerNormPerDeviceState( - FFHandler handle, - bool elementwise_affine_, - int64_t effective_batch_size_, - int64_t effective_num_elements_, - bool profiling_, - float eps_) - : PerDeviceOpState(handle) { - elementwise_affine = elementwise_affine_; - effective_batch_size = effective_batch_size_; - effective_num_elements = 
effective_num_elements_; - profiling = profiling_; - eps = eps_; - checkCUDA(cudaMalloc(&mean_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(cudaMalloc(&rstd_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(cudaMalloc(&ds_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(cudaMalloc(&db_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(cudaMalloc(&scale_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(cudaMalloc(&bias_ptr, sizeof(float) * effective_batch_size)); -} - namespace Kernels { namespace LayerNorm { +LayerNormPerDeviceState init_kernel(PerDeviceFFHandle handle, + int64_t effective_batch_size) { + float *mean, *rstd, *ds, *db, *scale, *bias; + checkCUDA(cudaMalloc(&mean, sizeof(float) * batch_size)); + checkCUDA(cudaMalloc(&rstd, sizeof(float) * batch_size)); + checkCUDA(cudaMalloc(&ds, sizeof(float) * batch_size)); + checkCUDA(cudaMalloc(&db, sizeof(float) * batch_size)); + checkCUDA(cudaMalloc(&scale, sizeof(float) * batch_size)); + checkCUDA(cudaMalloc(&bias, sizeof(float) * batch_size)); + + LayerNormPerDeviceState per_device_state = {mean, rstd, ds, db, scale, bias}; + return per_device_state; +} + template struct ForwardKernel { void operator()(cudaStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorW const &gamma, - GenericTensorAccessorW const &beta) { + GenericTensorAccessorW const &beta, + int64_t batch_size, + int64_t num_elements, + float eps) { RowwiseMomentsCUDAKernel - <<effective_batch_size, kCUDABlockReduceNumThreads, 0, stream>>>( - m->effective_num_elements, - m->eps, - input.get(), - m->mean_ptr, - m->rstd_ptr); + <<>>( + num_elements, eps, input.get(), m.mean, m.rstd); LayerNormForwardCUDAKernel - <<effective_batch_size, kCUDANumThreads, 0, stream>>>( - m->effective_num_elements, - input.get(), - m->mean_ptr, - m->rstd_ptr, - gamma.get(), - beta.get(), - output.get()); + <<>>(num_elements, + input.get(), + m.mean, + m.rstd, + gamma.get(), + beta.get(), + output.get()); } } template struct BackwardKernel { void operator()(cudaStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &gamma, GenericTensorAccessorW const &gamma_grad, - GenericTensorAccessorW const &beta_grad) { - const int64_t M = m->effective_batch_size; - const int64_t N = m->effective_num_elements; + GenericTensorAccessorW const &beta_grad, + int64_t batch_size, + int64_t num_elements, + float eps) { + const int64_t M = batch_size; + const int64_t N = num_elements; ComputeInternalGradientsCUDAKernel <<>>(N, output_grad.get(), input.get(), gamma.get(), - m->ds_ptr, - m->db_ptr); + m.ds, + m.db); const int64_t B = (M + kCUDANumThreads - 1) / kCUDANumThreads; - ComputeGradientFusedParamsCUDAKernel - <<>>(M, - N, - m->mean_ptr, - m->rstd_ptr, - m->ds_ptr, - m->db_ptr, - m->scale_ptr, - m->bias_ptr); + ComputeGradientFusedParamsCUDAKernel<<>>( + M, N, m.mean, m.rstd, m.ds, m.db, m.scale, m.bias); if (gamma_grad.get() != NULL || beta_grad.get() != NULL) { if (M < 512) { // For small batch size, do colwise reduce directly @@ -113,8 +100,8 @@ struct BackwardKernel { N, output_grad.get(), input.get(), - m->mean_ptr, - m->rstd_ptr, + m.mean, + m.rstd, gamma_grad.get(), beta_grad.get()); } else { @@ -127,8 +114,8 @@ struct BackwardKernel { N, 
output_grad.get(), input.get(), - m->mean_ptr, - m->rstd_ptr, + m.mean, + m.rstd, gamma_grad.get(), beta_grad.get()); } @@ -137,24 +124,40 @@ struct BackwardKernel { } void forward_kernel(cudaStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorW const &gamma, - GenericTensorAccessorW const &beta) { - DataTypeDispatch1{}( - m->data_type, stream, m, input, output, gamma, beta); + GenericTensorAccessorW const &beta, + DataType data_type, + int64_t batch_size, + int64_t num_elements, + float eps) { + DataTypeDispatch1{}(data_type, + stream, + m, + input, + output, + gamma, + beta, + batch_size, + num_elements, + eps); } void backward_kernel(cudaStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &gamma, GenericTensorAccessorW const &gamma_grad, - GenericTensorAccessorW const &beta_grad) { - DataTypeDispatch1{}(m->data_type, + GenericTensorAccessorW const &beta_grad, + DataType data_type, + int64_t batch_size, + int64_t num_elements, + float eps) { + DataTypeDispatch1{}(data_type, stream, m, output_grad, @@ -162,7 +165,10 @@ void backward_kernel(cudaStream_t stream, input_grad, gamma, gamma_grad, - beta_grad); + beta_grad, + batch_size, + num_elements, + eps); } template diff --git a/lib/kernels/src/hip/element_unary_kernels.cpp b/lib/kernels/src/hip/element_unary_kernels.cpp index 58bec1b262..01bb7ca5f9 100644 --- a/lib/kernels/src/hip/element_unary_kernels.cpp +++ b/lib/kernels/src/hip/element_unary_kernels.cpp @@ -18,26 +18,24 @@ #include namespace FlexFlow { +namespace Kernels { +namespace ElementUnary { -// declare Legion names -using Legion::coord_t; -using Legion::Domain; +ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + ArrayShape const &input_shape, + ArrayShape const &output_shape, + OperatorType op_type, + DataType data_type) { + miopenTensorDescriptor_t inputTensor; + miopenTensorDescriptor_t outputTensor; + miopenActivationDescriptor_t actiDesc; + miopenActivationMode_t mode; -ElementUnaryPerDeviceState::ElementUnaryPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) { checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); checkCUDNN(miopenCreateActivationDescriptor(&actiDesc)); -} -namespace Kernels { -namespace ElementUnary { - -void init_kernel(ElementUnaryPerDeviceState *m, - Domain const &input_domain, - Domain const &output_domain) { - miopenActivationMode_t mode; - switch (m->op_type) { + switch (op_type) { case OP_SIGMOID: mode = miopenActivationLOGISTIC; break; @@ -53,11 +51,16 @@ void init_kernel(ElementUnaryPerDeviceState *m, default: assert(false); } - checkCUDNN(miopenSetActivationDescriptor(m->actiDesc, mode, 0.0, 0.0, 0.0)); - checkCUDNN(cudnnSetTensorDescriptorFromDomain(m->inputTensor, input_domain)); + checkCUDNN(miopenSetActivationDescriptor(actiDesc, mode, 0.0, 0.0, 0.0)); + checkCUDNN(cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape)); // input_domain == output_domain checkCUDNN( - cudnnSetTensorDescriptorFromDomain(m->outputTensor, output_domain)); + cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape)); + + ElementUnaryPerDeviceState per_device_state = { + handle, inputTensor, outputTensor, actiDesc, op_type, data_type, 
scalar}; + + return per_device_state; } bool use_cudnn(OperatorType type) { @@ -82,16 +85,16 @@ struct ForwardKernel { ElementUnaryPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); + if (use_cudnn(m.op_type)) { float alpha = 1.0f, beta = 0.0f; - checkCUDNN(miopenActivationForward(m->handle.dnn, - m->actiDesc, + checkCUDNN(miopenActivationForward(m.handle.dnn, + m.actiDesc, &alpha, - m->inputTensor, + m.inputTensor, input.get(), &beta, - m->outputTensor, + m.outputTensor, output.get())); } else { size_t num_elements = input.shape.num_elements(); @@ -101,8 +104,8 @@ struct ForwardKernel { 0, stream, num_elements, - (T)m->scalar, - m->op_type, + (T)m.scalar, + m.op_type, input.get(), output.get()); } @@ -117,22 +120,22 @@ struct BackwardKernel { GenericTensorAccessorR const &input_grad, GenericTensorAccessorW const &output, GenericTensorAccessorW const &output_grad) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + if (use_cudnn(m.op_type)) { float alpha = 1.0f; float beta = 0.0f; - checkCUDNN(miopenActivationBackward(m->handle.dnn, - m->actiDesc, + checkCUDNN(miopenActivationBackward(m.handle.dnn, + m.actiDesc, &alpha, - m->outputTensor, + m.outputTensor, output.get(), - m->outputTensor, + m.outputTensor, output_grad.get()), - m->inputTensor, + m.inputTensor, input.get(), &beta, - m->inputTensor, + m.inputTensor, input_grad.get()); } else { size_t num_elements = input.shape.num_elements(); @@ -142,8 +145,8 @@ struct BackwardKernel { 0, stream, num_elements, - m->scalar, - m->op_type, + m.scalar, + m.op_type, output.get(), output_grad.get(), input.get(), @@ -151,21 +154,19 @@ struct BackwardKernel { } } } void forward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - { - DataTypeDispatch1{}(m->data_type, stream, m, input, output); - } + { DataTypeDispatch1{}(m.data_type, stream, m, input, output); } void backward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &input_grad, - GenericTensorAccessorW const &output, - GenericTensorAccessorW const &output_grad) + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad) DataTypeDispatch1{}( - m->data_type, stream, m, input, input_grad, output, output_grad); + m.data_type, stream, m, input, input_grad, output, output_grad); } template diff --git a/lib/kernels/src/hip/embedding_kernels.cpp b/lib/kernels/src/hip/embedding_kernels.cpp index 93bb7276cb..17edfea5c1 100644 --- a/lib/kernels/src/hip/embedding_kernels.cpp +++ b/lib/kernels/src/hip/embedding_kernels.cpp @@ -25,7 +25,7 @@ namespace Embedding { template struct ForwardKernel { void operator()(hipStream_t stream, - EmbeddingPerDeviceState const *m, + AggrMode aggr, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorR const &weight, @@ -36,9 +36,9 @@ struct ForwardKernel { assert(weight.data_type == DT_HALF || weight.data_type == DT_FLOAT || weight.data_type == DT_DOUBLE); - if (m->aggr == AGGR_MODE_NONE) { + if (aggr == 
AGGR_MODE_NONE) { hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_no_aggr), - GET_BLOCKS(output.domain.get_volume()), + GET_BLOCKS(output.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, @@ -49,7 +49,7 @@ struct ForwardKernel { batch_size); } else { hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_with_aggr), - GET_BLOCKS(output.domain.get_volume()), + GET_BLOCKS(output.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, @@ -59,7 +59,7 @@ struct ForwardKernel { out_dim, in_dim, batch_size, - m->aggr); + aggr); } } } @@ -67,7 +67,7 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(hipStream_t stream, - EmbeddingPerDeviceState const *m, + AggrMode aggr, GenericTensorAccessorR const &input, GenericTensorAccessorR const &output, GenericTensorAccessorW const &weight_grad, @@ -77,9 +77,9 @@ struct BackwardKernel { assert(input.data_type == DT_INT32 || input.data_type == DT_INT64); assert(output.data_type == DT_HALF || output.data_type == DT_FLOAT, || output.data_type == DT_DOUBLE); - if (m->aggr == AGGR_MODE_NONE) { + if (aggr == AGGR_MODE_NONE) { hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_no_aggr), - GET_BLOCKS(output.domain.get_volume()), + GET_BLOCKS(output.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, @@ -90,7 +90,7 @@ struct BackwardKernel { batch_size); } else { hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_with_aggr), - GET_BLOCKS(output.domain.get_volume()), + GET_BLOCKS(output.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, @@ -100,23 +100,25 @@ struct BackwardKernel { out_dim, in_dim, batch_size, - m->aggr); + aggr); } } } void forward_kernel(hipStream_t stream, - EmbeddingPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorR const &weight, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size) { - DataTypeDispatch2{}(m->input_data_type, - m->output_data_type, + DataTypeDispatch2{}(input_data_type, + output_data_type, stream, - m, + aggr, input, output, weight, @@ -126,17 +128,19 @@ void forward_kernel(hipStream_t stream, } void backward_kernel(hipStream_t stream, - EmbeddingPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorR const &output, GenericTensorAccessorW const &weight_grad, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size) { - DataTypeDispatch2{}(m->input_data_type, - m->output_data_type, + DataTypeDispatch2{}(input_data_type, + output_data_type, stream, - m, + aggr, input, output, weight, diff --git a/lib/kernels/src/hip/layer_norm_kernels.cpp b/lib/kernels/src/hip/layer_norm_kernels.cpp index dc2685ef28..5cf82b213e 100644 --- a/lib/kernels/src/hip/layer_norm_kernels.cpp +++ b/lib/kernels/src/hip/layer_norm_kernels.cpp @@ -24,57 +24,53 @@ constexpr int kCUDABlockReduceNumThreads = 512; constexpr int kCUDANumThreads = 256; constexpr int kColwiseReduceTileSize = 32; -LayerNormPerDeviceState::LayerNormPerDeviceState( - FFHandler handle, - bool elementwise_affine_, - int64_t effective_batch_size_, - int64_t effective_num_elements_, - bool profiling_, - float eps_) - : PerDeviceOpState(handle) { - elementwise_affine = elementwise_affine_; - effective_batch_size = effective_batch_size_; - effective_num_elements = effective_num_elements_; - profiling = profiling_; - eps = eps_; - checkCUDA(hipMalloc(&mean_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(hipMalloc(&rstd_ptr, sizeof(float) * 
effective_batch_size)); - checkCUDA(hipMalloc(&ds_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(hipMalloc(&db_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(hipMalloc(&scale_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(hipMalloc(&bias_ptr, sizeof(float) * effective_batch_size)); -} - namespace Kernels { namespace LayerNorm { +LayerNormPerDeviceState init_kernel(PerDeviceFFHandle handle, + int64_t effective_batch_size) { + float *mean, *rstd, *ds, *db, *scale, *bias; + checkCUDA(cudaMalloc(&mean, sizeof(float) * batch_size)); + checkCUDA(cudaMalloc(&rstd, sizeof(float) * batch_size)); + checkCUDA(cudaMalloc(&ds, sizeof(float) * batch_size)); + checkCUDA(cudaMalloc(&db, sizeof(float) * batch_size)); + checkCUDA(cudaMalloc(&scale, sizeof(float) * batch_size)); + checkCUDA(cudaMalloc(&bias, sizeof(float) * batch_size)); + + LayerNormPerDeviceState per_device_state = {mean, rstd, ds, db, scale, bias}; + return per_device_state; +} + template struct ForwardKernel { void operator()(hipStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorW const &gamma, - GenericTensorAccessorW const &beta) { + GenericTensorAccessorW const &beta, + int64_t batch_size, + int64_t num_elements, + float eps) { hipLaunchKernelGGL(HIP_KERNEL_NAME(RowwiseMomentsCUDAKernel), - m->effective_batch_size, + batch_size, kCUDABlockReduceNumThreads, 0, stream, - m->effective_num_elements, - m->eps, + num_elements, + m.eps, input.get(), - m->mean_ptr, - m->rstd_ptr); + m.mean, + m.rstd); hipLaunchKernelGGL(HIP_KERNEL_NAME(LayerNormForwardCUDAKernel), - m->effective_batch_size, + batch_size, kCUDANumThreads, 0, stream, - m->effective_num_elements, + num_elements, input.get(), - m->mean_ptr, - m->rstd_ptr, + m.mean, + m.rstd, gamma.get(), beta.get(), output.get()); @@ -84,15 +80,18 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(hipStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &gamma, GenericTensorAccessorW const &gamma_grad, - GenericTensorAccessorW const &beta_grad) { - const int64_t M = m->effective_batch_size; - const int64_t N = m->effective_num_elements; + GenericTensorAccessorW const &beta_grad, + int64_t batch_size, + int64_t num_elements, + float eps) { + const int64_t M = batch_size; + const int64_t N = num_elements; hipLaunchKernelGGL(HIP_KERNEL_NAME(ComputeInternalGradientsCUDAKernel), M, kCUDABlockReduceNumThreads, @@ -102,8 +101,8 @@ struct BackwardKernel { output_grad.get(), input.get(), gamma.get(), - m->ds_ptr, - m->db_ptr); + m.ds, + m.db); const int64_t B = (M + kCUDANumThreads - 1) / kCUDANumThreads; hipLaunchKernelGGL(HIP_KERNEL_NAME(ComputeGradientFusedParamsCUDAKernel), B, @@ -112,12 +111,12 @@ struct BackwardKernel { stream, M, N, - m->mean_ptr, - m->rstd_ptr, - m->ds_ptr, - m->db_ptr, - m->scale_ptr, - m->bias_ptr); + m.mean, + m.rstd, + m.ds, + m.db, + m.scale, + m.bias); if (gamma_grad.get() != NULL || beta_grad.get() != NULL) { if (M < 512) { // For small batch size, do colwise reduce directly @@ -132,8 +131,8 @@ struct BackwardKernel { N, output_grad.get(), input.get(), - m->mean_ptr, - m->rstd_ptr, + m.mean, + m.rstd, gamma_grad.get(), beta_grad.get()); } else { @@ -150,8 +149,8 @@ struct BackwardKernel { N, 
output_grad.get(), input.get(), - m->mean_ptr, - m->rstd_ptr, + m.mean, + m.rstd, gamma_grad.get(), beta_grad.get()); } @@ -159,24 +158,40 @@ struct BackwardKernel { } void forward_kernel(hipStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorW const &gamma, - GenericTensorAccessorW const &beta) { - DataTypeDispatch1{}( - m->data_type, stream, m, input, output, gamma, beta); + GenericTensorAccessorW const &beta, + DataType data_type, + int64_t batch_size, + int64_t num_elements, + float eps) { + DataTypeDispatch1{}(data_type, + stream, + m, + input, + output, + gamma, + beta, + batch_size, + num_elements, + eps); } void backward_kernel(hipStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &gamma, GenericTensorAccessorW const &gamma_grad, - GenericTensorAccessorW const &beta_grad) { - DataTypeDispatch1{}(m->data_type, + GenericTensorAccessorW const &beta_grad, + DataType data_type, + int64_t batch_size, + int64_t num_elements, + float eps) { + DataTypeDispatch1{}(data_type, stream, m, output_grad, @@ -184,7 +199,10 @@ struct BackwardKernel { input_grad, gamma, gamma_grad, - beta_grad); + beta_grad, + batch_size, + num_elements, + eps); } template diff --git a/lib/runtime/src/ops/element_unary.cc b/lib/runtime/src/ops/element_unary.cc index 07959bd6da..4490b40daa 100644 --- a/lib/runtime/src/ops/element_unary.cc +++ b/lib/runtime/src/ops/element_unary.cc @@ -6,665 +6,296 @@ namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::ElementUnary; -Tensor FFModel::unary(OperatorType op, - const Tensor x, - bool inplace, - char const *name, - float scalar) { - Layer *ele = nullptr; - DataType dtype; - // FIXME: currently cast input to float if it has a lower type - if (x->data_type < DT_FLOAT) { - dtype = DT_FLOAT; - std::string str(name); - Tensor new_x = cast(x, dtype, (str + "input_pre_cast").c_str()); - ele = new Layer(this, - op, - dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - new_x); - } else { - dtype = x->data_type; - ele = new Layer( - this, op, dtype, name, 1 /*inputs*/, 0 /*weights*/, 1 /*outputs*/, x); - } - int numdims = x->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = x->dims[i]; - } - ele->outputs[0] = create_tensor_legion_ordering( - numdims, dims, dtype, ele, 0, true /*create_grad*/); - ele->add_int_property("inplace", inplace); - ele->add_float_property("scalar", scalar); - layers.push_back(ele); - return ele->outputs[0]; -} +enum Slots { + INPUT, + INPUT_SHAPE, + OUTPUT, + ATTRS, + HANDLE, + PROFILING, + PER_DEVICE_STATE +}; -Op *ElementUnary::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("inplace", value); - bool inplace = (bool)value; - float scalar; - layer->get_float_property("scalar", scalar); - return new 
ElementUnary( - model, layer->op_type, inputs[0], inplace, layer->name, scalar); -} +/* ElementUnary */ +OpTaskInvocation init(ElementUnaryAttrs const &attrs) { + OpTaskBinding b; -Tensor FFModel::exp(const Tensor x, char const *name) { - return this->unary(OP_EXP, x, false /*inplace*/, name); -} + b.bind_arg(HANDLE, ff_handle()); + b.bind_arg(ATTRS, attrs); + b.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); -Tensor FFModel::scalar_multiply(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_MULTIPLY, x, inplace, name, scalar); + return {ELEMENTUNARY_INIT_TASK_ID, b}; } -Tensor FFModel::scalar_add(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_ADD, x, inplace, name, scalar); -} +OpTaskInvocation forward(ElementUnaryAttrs const &attrs) { + OpTaskBinding b; -Tensor FFModel::scalar_sub(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_SUB, x, inplace, name, scalar); -} + b.bind(INPUT, input_tensor(0)); + b.bind(OUTPUT, output_tensor(0)); -Tensor FFModel::scalar_truediv(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_TRUE_DIV, x, inplace, name, scalar); -} + b.bind_arg(PROFILING, profiling_settings()); + b.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); -Tensor FFModel::relu(const Tensor x, bool inplace, char const *name) { - return this->unary(OP_RELU, x, inplace, name); + return {ELEMENTUNARY_FWD_TASK_ID, b}; } -Tensor FFModel::sigmoid(const Tensor x, char const *name) { - return this->unary(OP_SIGMOID, x, false /*inplace*/, name); -} +OpTaskInvocation backward(ElementUnaryAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); -Tensor FFModel::tanh(const Tensor x, char const *name) { - return this->unary(OP_TANH, x, false /*inplace*/, name); + return {ELEMENTUNARY_BWD_TASK_ID, b}; } -Tensor FFModel::identity(const Tensor x, char const *name) { - return this->unary(OP_IDENTITY, x, false /*inplace*/, name); -} +/* ElementScalarUnary */ +OpTaskInvocation init(ElementScalarUnaryAttrs const &attrs) { + OpTaskBinding b; -Tensor FFModel::gelu(const Tensor x, char const *name) { - return this->unary(OP_GELU, x, false /*inplace*/, name); -} + b.bind_arg(HANDLE, ff_handle()); + b.bind_arg(ATTRS, attrs); + b.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); -Tensor FFModel::elu(const Tensor x, bool inplace, char const *name) { - // Currently assume inplace is false - assert(!inplace); - return this->unary(OP_ELU, x, inplace, name); + return {ELEMENTUNARY_INIT_TASK_ID, b}; } -Tensor FFModel::rsqrt(const Tensor x, bool inplace, char const *name) { - return this->unary(OP_RSQRT, x, inplace, name); -} +OpTaskInvocation forward(ElementScalarUnaryAttrs const &attrs) { + OpTaskBinding b; -Tensor FFModel::pow(const Tensor x, - float const exponent, - bool inplace, - char const *name) { - return this->unary(OP_POW, x, inplace, name, exponent); -} + b.bind(INPUT, input_tensor(0)); + b.bind(OUTPUT, output_tensor(0)); -Tensor FFModel::sin(const Tensor x, char const *name) { - return this->unary(OP_SIN, x, false /*inplace*/, name); -} + b.bind_arg(PROFILING, profiling_settings()); + b.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); -Tensor FFModel::cos(const Tensor x, char const *name) { - return this->unary(OP_COS, x, false /*inplace*/, name); + return {ELEMENTUNARY_FWD_TASK_ID, b}; } -bool 
ElementUnaryParams::is_valid(ParallelTensorShape const &input) const { - return input.is_valid(); -} +OpTaskInvocation backward(ElementScalarUnaryAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); -bool operator==(ElementUnaryParams const &lhs, ElementUnaryParams const &rhs) { - return lhs.op_type == rhs.op_type && lhs.scalar == rhs.scalar && - lhs.inplace == rhs.inplace; + return {ELEMENTUNARY_BWD_TASK_ID, b}; } -ElementUnary::ElementUnary(FFModel &model, - OperatorType _op_type, - const ParallelTensor x, - bool _inplace, - char const *name, - float _scalar) - : Op(model, - _op_type, - x->data_type, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - x), - inplace(_inplace), scalar(_scalar) { - numOutputs = 1; - int numdim = x->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = x->dims[i]; - } - outputs[0] = model.create_parallel_tensor_legion_ordering( - numdim, dims, inputs[0]->data_type, this); - // Disable inplace if shape mismatch - if (outputs[0]->get_shape() != inputs[0]->get_shape()) { - inplace = false; - } -} +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { -ElementUnary::ElementUnary(FFModel &model, - ElementUnaryParams const ¶ms, - const ParallelTensor input, - char const *name) - : ElementUnary( - model, params.op_type, input, params.inplace, name, params.scalar) {} - -void ElementUnary::map_output_tensors(FFModel &ff) { - if (has_inplace_output()) { - assert(numOutputs == 1); - assert(outputs[0]->get_volume() == inputs[0]->get_volume()); - outputs[0]->parallel_is = inputs[0]->parallel_is; - outputs[0]->region = inputs[0]->region; - outputs[0]->part = inputs[0]->part; - outputs[0]->region_grad = inputs[0]->region_grad; - outputs[0]->part_grad = inputs[0]->part_grad; - } else { - Op::map_output_tensors(ff); - } -} + auto const &attrs = acc.get_argument(ATTRS); + ProfilingSettings profiling = acc.get_argument(PROFILING); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + ParallelTensorShape input_shape = + acc.get_argument(INPUT_SHAPE); + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); -bool ElementUnary::can_inplace_output(void) { - return outputs[0]->get_shape() == inputs[0]->get_shape(); + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, + {input_shape.dims}, + {output_shape.dims}, + input_shape.data_type)); + return per_device_state; } -bool ElementUnary::has_inplace_output(void) { - return inplace; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -void ElementUnary::do_inplace_output(void) { - inplace = true; -} +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); -void ElementUnary::init(FFModel const &ff) { - assert(check_output_input_weight_same_parallel_is()); - parallel_is = outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_init(ff, argmap); - IndexLauncher init_launcher(ELEMENTUNARY_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(ElementUnary)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - if (!inplace) { - init_launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 
/*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - init_launcher.add_field(0, FID_DATA); - init_launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - init_launcher.add_field(1, FID_DATA); - } else { - init_launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - inputs[0]->region)); - init_launcher.add_field(0, FID_DATA); - } - FutureMap fm = runtime->execute_index_space(ctx, init_launcher); - fm.wait_all_results(); - set_opmeta_from_futuremap(ff, fm); -} + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); -PerDeviceOpState * - ElementUnary::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - ElementUnary *eu = (ElementUnary *)task->args; - FFHandler handle = *((FFHandler *)task->local_args); - ElementUnaryMeta *m = new ElementUnaryMeta(handle); - m->op_type = eu->op_type; - m->data_type = eu->outputs[0]->data_type; - // Input and output should have the same data type - assert(eu->outputs[0]->data_type == eu->inputs[0]->data_type); - m->profiling = eu->profiling; - m->inplace = eu->inplace; - m->scalar = eu->scalar; - std::strcpy(m->op_name, eu->name); - if (m->inplace) { - assert(regions.size() == 1); - assert(task->regions.size() == 1); - } else { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - } - - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - init_kernel(m, input_domain, input_domain); - return m; + return profile(forward_kernel, + profiling, + "[ElementUnary] forward_time = %.2lfms\n", + per_device_state, + input, + output); } -void ElementUnary::forward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_forward(ff, argmap); - IndexLauncher launcher(ELEMENTUNARY_FWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - if (inplace) { - assert(outputs[0]->part == inputs[0]->part); - assert(outputs[0]->region == inputs[0]->region); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(0, FID_DATA); - } else { - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(1, FID_DATA); - } - runtime->execute_index_space(ctx, launcher); +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -void ElementUnary::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - if (m->data_type == DT_FLOAT) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_DOUBLE) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT32) { - 
forward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT64) { - forward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Embedding forward"); - } -} +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(INPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); -/* - regions[0](I): input - regions[1](O): output -*/ -template -void ElementUnary::forward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - // const ElementUnary* ele = (const ElementUnary*) task->args; - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - const DT *input_ptr = NULL; - DT *output_ptr = NULL; - if (m->inplace) { - assert(regions.size() == 1); - assert(task->regions.size() == 1); - output_ptr = helperGetTensorPointerRW
<DT>( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - input_ptr = output_ptr; - } else { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(output_domain == input_domain); - input_ptr = helperGetTensorPointerRO
<DT>( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - output_ptr = helperGetTensorPointerWO
<DT>( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - } - - forward_kernel_wrapper<DT>
( - m, input_ptr, output_ptr, input_domain.get_volume()); + return profile(backward_kernel, + profiling, + "[ElementUnary] backward_time = %.2lfms\n", + per_device_state, + input, + input_grad, + output, + output_grad); } -void ElementUnary::backward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_backward(ff, argmap); - IndexLauncher launcher(ELEMENTUNARY_BWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - if (inplace) { - assert(inputs[0]->part == outputs[0]->part); - assert(inputs[0]->part_grad == outputs[0]->part_grad); - // regions[2](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(0, FID_DATA); - // regions[3](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - } else { - // regions[0](I): input - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - // regions[1](I/O): input_grad - launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - inputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - // regions[2](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(2, FID_DATA); - // regions[3](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(3, FID_DATA); - } - runtime->execute_index_space(ctx, launcher); +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -void ElementUnary::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - if (m->data_type == DT_FLOAT) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_DOUBLE) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT32) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT64) { - backward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Embedding forward"); - } +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + ElementUnaryAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); + + SimTaskBinding init_binding; + init_binding.bind_arg(HANDLE, ff_handle()); + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); + + auto init_accessor = + env.get_init_accessor(ELEMENTUNARY_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + 
init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = + env.get_fwd_accessor(ELEMENTUNARY_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(ELEMENTUNARY_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -/* - regions[0](I): input - regions[1](I/O): input_grad - regions[2](I): output - regions[3](I): output_grad -*/ -template -void ElementUnary::backward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - // const ElementUnary* ele = (const ElementUnary*) task->args; - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - const DT *input_ptr = NULL, *output_ptr = NULL, *output_grad_ptr = NULL; - DT *input_grad_ptr = NULL; - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - if (m->inplace) { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - Domain input_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(input_grad_domain == input_domain); - input_ptr = helperGetTensorPointerRO
<DT>( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - input_grad_ptr = helperGetTensorPointerRW
<DT>( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - output_ptr = input_ptr; - output_grad_ptr = input_grad_ptr; - } else { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - Domain input_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); - Domain output_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[3].region.get_index_space()); - assert(output_grad_domain == input_domain); - assert(output_grad_domain == output_domain); - assert(output_grad_domain == input_grad_domain); - input_ptr = helperGetTensorPointerRO
<DT>( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - input_grad_ptr = helperGetTensorPointerRW
<DT>( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - output_ptr = helperGetTensorPointerRO
<DT>( - regions[2], task->regions[2], FID_DATA, ctx, runtime); - output_grad_ptr = helperGetTensorPointerRO
<DT>( - regions[3], task->regions[3], FID_DATA, ctx, runtime); - } - - backward_kernel_wrapper<DT>
(m, - input_ptr, - input_grad_ptr, - output_ptr, - output_grad_ptr, - input_domain.get_volume()); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + ElementScalarUnaryAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); + + SimTaskBinding init_binding; + init_binding.bind_arg(HANDLE, ff_handle()); + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); + + auto init_accessor = + env.get_init_accessor(ELEMENTUNARY_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = + env.get_fwd_accessor(ELEMENTUNARY_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(ELEMENTUNARY_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -void ElementUnary::serialize(Legion::Serializer &sez) const { - sez.serialize(this->op_type); - sez.serialize(this->inplace); - sez.serialize(scalar); +template <> +OpTaskSignature init_signature() { + OpTaskSignature init(OpTaskType::INIT); + init.add_arg_slot(INPUT_SHAPE); + init.add_arg_slot(ATTRS); + init.add_unchecked_arg_slot(HANDLE); + + init.add_return_value(); + + return init; } -bool ElementUnary::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - ElementUnaryMeta *m = sim->ele_unary_meta; - m->op_type = op_type; - Domain input_domain, output_domain; - input_domain.dim = sub_input.num_dims; - for (int i = 0; i < sub_input.num_dims; i++) { - input_domain.rect_data[i] = 0; - input_domain.rect_data[i + input_domain.dim] = sub_input.dims[i].size - 1; - } - output_domain.dim = sub_output.num_dims; - for (int i = 0; i < sub_output.num_dims; i++) { - output_domain.rect_data[i] = 0; - output_domain.rect_data[i + input_domain.dim] = sub_output.dims[i].size - 1; - } - init_kernel(m, input_domain, output_domain); - sim->free_all(); - float *input_ptr = - (float *)sim->allocate(sub_input.get_volume(), inputs[0]->data_type); - assert(input_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_ptr = NULL; - if (inplace) { - output_ptr = input_ptr; - } else { - output_ptr = - (float *)sim->allocate(sub_output.get_volume(), outputs[0]->data_type); - } - assert(output_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - assert(m->profiling == false); - - std::function forward, backward; - forward = [&] { - forward_kernel_wrapper(m, input_ptr, output_ptr, sub_output.get_volume()); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - float *input_grad_ptr = - (float 
*)sim->allocate(sub_input.get_volume(), inputs[0]->data_type); - assert(input_grad_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_grad_ptr = NULL; - if (inplace) { - output_grad_ptr = input_grad_ptr; - } else { - output_grad_ptr = (float *)sim->allocate(sub_output.get_volume(), - outputs[0]->data_type); - } - assert(output_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&] { - backward_kernel_wrapper(m, - input_ptr, - input_grad_ptr, - output_ptr, - output_grad_ptr, - sub_output.get_volume()); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - log_measure.debug("[Measure Elewise Unary] name(%s) num_elements(%zu) " - "forward_time(%.4lf) backward_time(%.4lf)\n", - name, - sub_output.get_volume(), - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - log_measure.debug("[Measure Elewise Unary] name(%s) num_elements(%zu) " - "forward_time(%.4lf)\n", - name, - sub_output.get_volume(), - cost_metrics.forward_time); - } - return true; +template <> +void register_task() { + register_task(ELEMENTUNARY_INIT_TASK_ID, + "MultiHeadAttention Init", + init_signature(), + init_task); } -using PCG::Node; -/*static*/ -Node ElementUnary::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 1); - OperatorType op_type; - float scalar; - bool inplace; - dez.deserialize(op_type); - dez.deserialize(inplace); - dez.deserialize(scalar); - - ElementUnaryParams params; - params.op_type = op_type; - params.inplace = inplace; - params.scalar = scalar; - return ff.get_or_create_node(inputs[0], params); +template <> +OpTaskSignature fwd_signature() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + + return fwd; } -Op *ElementUnary::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - assert(num_inputs == 1); - return new ElementUnary( - ff, this->op_type, inputs[0], this->inplace, this->name, this->scalar); +template <> +void register_task() { + register_task(ELEMENTUNARY_FWD_TASK_ID, + "MultiHeadAttention Fwd", + fwd_signature(), + forward_task); } -}; // namespace FlexFlow +template <> +OpTaskSignature bwd_signature() { + OpTaskSignature bwd = + infer_bwd_signature(fwd_signature()); -namespace std { -size_t hash::operator()( - FlexFlow::ElementUnaryParams const ¶ms) const { - size_t key = 0; - hash_combine(key, params.op_type); - hash_combine(key, params.scalar); - hash_combine(key, params.inplace); - return key; + return bwd; } -}; // namespace std + +template <> +void register_task() { + register_task(ELEMENTUNARY_BWD_TASK_ID, + "MultiHeadAttention Bwd", + bwd_signature(), + backward_task); +} + +} // namespace FlexFlow diff --git a/lib/runtime/src/ops/element_unary.h b/lib/runtime/src/ops/element_unary.h index ae661f1177..aa393821bd 100644 --- a/lib/runtime/src/ops/element_unary.h +++ b/lib/runtime/src/ops/element_unary.h @@ -24,76 +24,16 @@ OpTaskInvocation backward(ElementScalarUnaryAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, ElementUnaryAttrs const &attrs, - ParallelTensorShape const &input_shape, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); 
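
Shown below, under stated assumptions, is a sketch of the init_task_impl that the cost-measurement path above invokes but whose body lies outside this hunk. The slot names (ATTRS, INPUT_SHAPE, HANDLE) and get_output_shape come from this patch; get_op_type, get_data_type, the use of get_piece_shape here, DeviceSpecific<...>::create, and the exact kernel-side init_kernel signature are illustrative assumptions rather than code from this patch.

// Hedged sketch only: mirrors how forward_task_impl/backward_task_impl are
// written in this patch; helper names not appearing in the diff are assumed.
static DeviceSpecific<ElementUnaryPerDeviceState>
    init_task_impl(TaskArgumentAccessor const &acc) {
  // Unpack the slots bound by init_binding in measure_operator_cost above.
  auto attrs = acc.get_argument<ElementScalarUnaryAttrs>(ATTRS);
  auto input_shape = acc.get_argument<ParallelTensorShape>(INPUT_SHAPE);
  auto handle = acc.get_argument<PerDeviceFFHandle>(HANDLE);

  ParallelTensorShape output_shape = get_output_shape(attrs, input_shape);

  // Hand the per-GPU piece shapes to the kernel-side initializer, which is
  // assumed to create the cuDNN tensor/activation descriptors and return the
  // plain per-device state struct.
  ElementUnaryPerDeviceState state =
      init_kernel(handle,
                  get_piece_shape(input_shape),
                  get_piece_shape(output_shape),
                  get_op_type(attrs),
                  get_data_type(attrs));

  return DeviceSpecific<ElementUnaryPerDeviceState>::create(state);
}

The returned DeviceSpecific state is what the cost-measurement path binds to PER_DEVICE_STATE before building the forward and backward accessors.
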
CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, ElementScalarUnaryAttrs const &attrs, - ParallelTensorShape const &input_shape, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); -/* class ElementUnary : public Op { */ -/* public: */ -/* ElementUnary(FFModel &model, */ -/* OperatorType type, */ -/* const ParallelTensor x, */ -/* bool inplace, */ -/* char const *name, */ -/* float scalar); */ -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* void map_output_tensors(FFModel &model) override; */ -/* bool can_inplace_output() override; */ -/* bool has_inplace_output() override; */ -/* void do_inplace_output() override; */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* template */ -/* static void */ -/* forward_task_with_type(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* template */ -/* static void backward_task_with_type( */ -/* Legion::Task const *task, */ -/* std::vector const ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* private: */ -/* bool inplace; */ - -/* public: */ -/* float scalar; */ -/* }; */ - } // namespace FlexFlow #endif diff --git a/lib/runtime/src/ops/embedding.cc b/lib/runtime/src/ops/embedding.cc index 281ad9bc26..a1bc915d2f 100644 --- a/lib/runtime/src/ops/embedding.cc +++ b/lib/runtime/src/ops/embedding.cc @@ -14,1187 +14,165 @@ */ #include "embedding.h" -#include "utils/hash_utils.h" +#include "kernels/embedding_kernels.h" +#include "legion.h" +#include "op-attrs/ops/embedding.h" namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::InlineLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Embedding; -Tensor FFModel::embedding(const Tensor input, - int num_entries, - int out_dim, - AggrMode aggr, - DataType dtype, - Layer const *shared_op, - Initializer *kernel_initializer, - char const *name) { - Layer *embed = new Layer(this, - OP_EMBEDDING, - dtype, - name, - 1 /*inputs*/, - 1 /*weights*/, - 1 /*outputs*/, - input); - if (aggr == AGGR_MODE_NONE) { - int numdims = input->num_dims + 1; - int dims[MAX_TENSOR_DIM]; - for (int i = 1; i < numdims; i++) { - dims[i] = input->dims[i - 1]; - } - dims[0] = out_dim; - embed->outputs[0] = create_tensor_legion_ordering( - numdims, dims, embed->data_type, embed, 0, true /*create_grad*/); - } else { - 
int numdims = input->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = input->dims[i]; - } - dims[0] = out_dim; - embed->outputs[0] = create_tensor_legion_ordering( - numdims, dims, embed->data_type, embed, 0, true /*create_grad*/); - } - { - int dims[2] = {out_dim, num_entries}; - embed->weights[0] = create_weight_legion_ordering(2, - dims, - dtype, - embed, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - } - embed->data_type = dtype; - embed->add_int_property("num_entries", num_entries); - embed->add_int_property("out_dim", out_dim); - embed->add_int_property("aggr_mode", aggr); - embed->add_initializer("kernel", kernel_initializer); - layers.push_back(embed); - return embed->outputs[0]; -} - -EmbeddingParams Embedding::get_params() const { - EmbeddingParams params; - params.num_entries = this->num_entries; - params.out_channels = this->out_channels; - params.aggr = this->aggr; - params.data_type = this->data_type; - // TODO: get rid of layer_guid - // https://github.com/flexflow/FlexFlow/issues/304 - params.layer_guid = this->layer_guid; - return params; -} - -Op *Embedding::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("num_entries", value); - int num_entries = value; - layer->get_int_property("out_dim", value); - int out_dim = value; - layer->get_int_property("aggr_mode", value); - AggrMode aggr = (AggrMode)value; - Initializer *kernel_initializer; - layer->get_initializer("kernel", kernel_initializer); - return new Embedding(model, - layer->layer_guid, - inputs[0], - num_entries, - out_dim, - aggr, - false /*allocate_weights*/, - layer->data_type, - layer->name); -} - -int Embedding::input_vocab_size_replica_dim() const { - return this->inputs[0]->num_dims - 1; -} - -int Embedding::input_channel_out_replica_dim() const { - return this->inputs[0]->num_dims - 2; -} - -int Embedding::output_vocab_size_replica_dim() const { - assert(this->outputs[0] != nullptr); - return this->outputs[0]->num_dims - 1; -} - -int Embedding::output_size(ParallelDim output_dims[MAX_TENSOR_DIM]) { - ParallelTensor const &input = this->inputs[0]; - - int const OUT_CHANNELS = Output::OUT_CHANNELS; - if (aggr == AGGR_MODE_NONE) { - int num_dims = input->num_dims + 1; - for (int i = 1; i < num_dims - 1; i++) { - output_dims[i] = input->dims[i - 1]; - } - assert(OUT_CHANNELS == 0); - output_dims[OUT_CHANNELS].size = this->out_channels; - output_dims[OUT_CHANNELS].degree = 1; - output_dims[OUT_CHANNELS].parallel_idx = -1; - // Currently do not support parallelizing over the replica dim - output_dims[num_dims - 1].size = 1; - output_dims[num_dims - 1].degree = 1; - output_dims[num_dims - 1].parallel_idx = -1; - output_dims[num_dims - 1].is_replica_dim = true; - return num_dims; - } else { - int num_dims = input->num_dims; - for (int i = 1; i < num_dims - 1; i++) { - output_dims[i] = input->dims[i]; - } - assert(OUT_CHANNELS == 0); - output_dims[OUT_CHANNELS].size = this->out_channels; - output_dims[OUT_CHANNELS].degree = 1; - output_dims[OUT_CHANNELS].parallel_idx = -1; - // Currently do not support parallelizing over the replica dim - output_dims[num_dims - 1].size = 1; - output_dims[num_dims - 1].degree = 1; - output_dims[num_dims - 1].parallel_idx = -1; - output_dims[num_dims - 1].is_replica_dim = true; - return num_dims; - } - // const int REPLICA = this->output_vocab_size_replica_dim(); -} +enum Slots { INPUT, WEIGHT, OUTPUT, ATTRS, PROFILING }; -int 
Embedding::weight_size(ParallelDim weight_dims[MAX_TENSOR_DIM]) { - ParallelTensor const &input = this->inputs[0]; +OpTaskInvocation forward(EmbeddingAttrs const &attrs) { + OpTaskBinding b; - weight_dims[Weight::OUT_CHANNELS].size = this->out_channels; - weight_dims[Weight::OUT_CHANNELS].degree = 1; - weight_dims[Weight::OUT_CHANNELS].parallel_idx = -1; - weight_dims[Weight::VOCAB_SIZE].size = this->num_entries; - weight_dims[Weight::VOCAB_SIZE].degree = 1; - weight_dims[Weight::VOCAB_SIZE].parallel_idx = -1; - for (int i = 2; i < input->num_dims; i++) { - weight_dims[i].size = input->dims[i - 1].degree; - weight_dims[i].degree = weight_dims[i].size; - weight_dims[i].parallel_idx = input->dims[i - 1].parallel_idx; - weight_dims[i].is_replica_dim = true; - } - return input->num_dims; -} + b.bind(INPUT, input_tensor(0)); + b.bind(WEIGHT, weight_tensor(0)); + b.bind(OUTPUT, output_tensor(0)); -void Embedding::register_output_mappings() { - if (aggr == AGGR_MODE_NONE) { - int num_dims = this->inputs[0]->num_dims + 1; - for (int i = 1; i < num_dims - 1; i++) { - this->register_output_parallel_dims(i - 1, i); - } - } else { - int num_dims = this->inputs[0]->num_dims; - for (int i = 1; i < num_dims - 1; i++) { - this->register_output_parallel_dims(i, i); - } - } -} + b.bind_arg(ATTRS, attrs); + b.bind_arg(PROFILING, profiling_settings()); -void Embedding::register_weight_mappings() { - for (int i = 2; i < this->inputs[0]->num_dims; i++) { - this->register_weight_parallel_dims(i - 1, i); - } + return {EMBED_FWD_TASK_ID, b}; } -void Embedding::register_mappings() { - this->register_output_mappings(); - this->register_weight_mappings(); -} +OpTaskInvocation backward(EmbeddingAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); -/* Params */ - -bool operator==(EmbeddingParams const &lhs, EmbeddingParams const &rhs) { - return lhs.layer_guid == rhs.layer_guid && - lhs.out_channels == rhs.out_channels && - lhs.num_entries == rhs.num_entries && lhs.aggr == rhs.aggr && - lhs.data_type == rhs.data_type; + return {EMBED_BWD_TASK_ID, b}; } -Embedding::Embedding(FFModel &model, - EmbeddingParams const ¶ms, - ParallelTensor const input, - bool allocate_weights, - char const *name) - : Embedding(model, - params.layer_guid, - input, - params.num_entries, - params.out_channels, - params.aggr, - allocate_weights, - params.data_type, - name) {} - -Embedding::Embedding(FFModel &model, - Embedding const &other, - const ParallelTensor input, - bool allocate_weights) - : Embedding(model, - other.layer_guid, - input, - other.num_entries, - other.out_channels, - other.aggr, - allocate_weights, - other.data_type, - other.name) {} - -Embedding::Embedding(FFModel &model, - LayerID const &_layer_guid, - const ParallelTensor _input, - int _num_entries, - int _out_channels, - AggrMode _aggr, - bool allocate_weights, - DataType dtype, - char const *name) - : Op(model, - OP_EMBEDDING, - dtype, - name, - 1 /*inputs*/, - 1 /*weights*/, - allocate_weights, - 1 /*outputs*/, - _input), - num_entries(_num_entries), out_channels(_out_channels), aggr(_aggr) { - layer_guid = _layer_guid; - std::vector weight_dim_sets; - - int weight_ndim; - ParallelDim weight_dims[MAX_TENSOR_DIM]; - if (allocate_weights) { - weight_ndim = this->weight_size(weight_dims); - weight_dim_sets.push_back(weight_dims); - } - - ParallelDim output_dims[MAX_TENSOR_DIM]; - int output_ndim = this->output_size(output_dims); - - // register mappings between inputs/weights and outputs - this->register_mappings(); - - 
this->solve_parallel_dim_mappings( - {_input->dims}, weight_dim_sets, {output_dims}); - - if (allocate_weights) { - Initializer *weight_initializer = new GlorotUniform(std::rand() /*seed*/); - // Initializer *weight_initializer = new ZeroInitializer(/*seed*/); - - weights[0] = - model.create_parallel_weight_legion_ordering(weight_ndim, - weight_dims, - dtype, - nullptr /*owner_op*/, - true /*create_grad*/, - weight_initializer, - CHOSEN_SYNC_TYPE); - } +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(INPUT); + auto weight = acc.get_tensor(WEIGHT); + auto output = acc.get_tensor(OUTPUT); - outputs[0] = model.create_parallel_tensor_legion_ordering( - output_ndim, output_dims, dtype, this); + ProfilingSettings profiling = acc.get_argument(PROFILING); + EmbeddingAttrs attrs = acc.get_argument(ATTRS); - assert(check_output_input_weight_parallel_dims(allocate_weights)); + return profile(forward_kernel, + profiling, + "[Embedding] forward_time = %.2lfms\n", + input, + output, + weight, + input.data_type, + output.data_type, + attrs.aggr, + input.shape.get_dim(), + output.shape.get_dim(), + input.shape[legion_dim_t(1)]); } -void Embedding::init(FFModel const &ff) { - assert(check_output_input_weight_same_parallel_is()); - parallel_is = outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_init(ff, argmap); - IndexLauncher launcher(EMBED_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(Embedding)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - // regions[0]: input - // launcher.add_region_requirement( - // RegionRequirement(input_lps[0], 0/*projection*/, - // READ_ONLY, EXCLUSIVE, inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // regions[1]: output - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(0, FID_DATA); - // regions[2]: weight - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(1, FID_DATA); - // regions[3]: input_grad - launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - 0 /*projection*/, - WRITE_ONLY, - EXCLUSIVE, - inputs[0]->region_grad)); - launcher.add_field(2, FID_DATA); - FutureMap fm = runtime->execute_index_space(ctx, launcher); - fm.wait_all_results(); - set_opmeta_from_futuremap(ff, fm); -} - -PerDeviceOpState * - Embedding::init_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - Embedding const *embed = (Embedding *)task->args; - FFHandler handle = *((FFHandler const *)task->local_args); - EmbeddingMeta *m = new EmbeddingMeta(handle, embed); - m->profiling = embed->profiling; - m->aggr = embed->aggr; - return m; -} - -void Embedding::forward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_forward(ff, argmap); - IndexLauncher launcher(EMBED_FWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - // regions[0]: input - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - 
launcher.add_field(0, FID_DATA); - // regions[1]: output - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region, - MAP_TO_ZC_MEMORY)); - launcher.add_field(1, FID_DATA); - // regions[2]: weight - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(2, FID_DATA); - runtime->execute_index_space(ctx, launcher); -} - -/* - regions[0](I): input - regions[1](O): output - regions[2](I): kernel -*/ -void Embedding::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - EmbeddingMeta const *m = *((EmbeddingMeta **)task->local_args); - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // Assert that weight and output must have the same data type - // otherwise, a cast operator should be inserted - assert(m->weight_type[0] == m->output_type[0]); - assert(m->input_type[0] == DT_INT32 || m->input_type[0] == DT_INT64); - GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( - m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); - GenericTensorAccessorR kernel = helperGetGenericTensorAccessorRO( - m->weight_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - if (m->aggr == AGGR_MODE_NONE) { - // assert(kernel_domain.get_dim() == 2); - assert(input.domain.get_dim() + 1 == output.domain.get_dim()); - for (size_t i = 0; i < input.domain.get_dim(); i++) { - assert(input.domain.hi()[i] == output.domain.hi()[i + 1]); - assert(input.domain.lo()[i] == output.domain.lo()[i + 1]); - } - assert(kernel.domain.hi()[0] - kernel.domain.lo()[0] == - output.domain.hi()[0] - output.domain.lo()[0]); - } else { - // assert(kernel_domain.get_dim() == 2); - assert(input.domain.get_dim() == output.domain.get_dim()); - for (size_t i = 1; i < input.domain.get_dim(); i++) { - assert(input.domain.hi()[i] == output.domain.hi()[i]); - assert(input.domain.lo()[i] == output.domain.lo()[i]); - } - assert(kernel.domain.hi()[0] - kernel.domain.lo()[0] == - output.domain.hi()[0] - output.domain.lo()[0]); - } - - int in_dim, out_dim, effective_batch_size; - if (m->aggr == AGGR_MODE_NONE) { - in_dim = 1; - out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1; - effective_batch_size = output.domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input.domain.get_volume()); - } else { - in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; - out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1; - effective_batch_size = output.domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input.domain.get_volume()); - } - forward_kernel_wrapper( - m, input, output, kernel, in_dim, out_dim, effective_batch_size); + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -#ifdef DEADCODE -template -void Embedding::forward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // const Embedding* embed = (Embedding*) task->args; - EmbeddingMeta const *m = *((EmbeddingMeta **)task->local_args); - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Domain output_domain = 
runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Domain kernel_domain = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); - if (m->aggr == AGGR_MODE_NONE) { - // assert(kernel_domain.get_dim() == 2); - assert(input_domain.get_dim() + 1 == output_domain.get_dim()); - for (size_t i = 0; i < input_domain.get_dim(); i++) { - assert(input_domain.hi()[i] == output_domain.hi()[i + 1]); - assert(input_domain.lo()[i] == output_domain.lo()[i + 1]); - } - assert(kernel_domain.hi()[0] - kernel_domain.lo()[0] == - output_domain.hi()[0] - output_domain.lo()[0]); - } else { - // assert(kernel_domain.get_dim() == 2); - assert(input_domain.get_dim() == output_domain.get_dim()); - for (size_t i = 1; i < input_domain.get_dim(); i++) { - assert(input_domain.hi()[i] == output_domain.hi()[i]); - assert(input_domain.lo()[i] == output_domain.lo()[i]); - } - assert(kernel_domain.hi()[0] - kernel_domain.lo()[0] == - output_domain.hi()[0] - output_domain.lo()[0]); - } - const TI *input_ptr = helperGetTensorPointerRO( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - float *output_ptr = helperGetTensorPointerWO( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - float const *kernel_ptr = helperGetTensorPointerRO( - regions[2], task->regions[2], FID_DATA, ctx, runtime); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto weight_grad = acc.get_tensor_grad(WEIGHT); - int in_dim, out_dim, effective_batch_size; - if (m->aggr == AGGR_MODE_NONE) { - in_dim = 1; - out_dim = output_domain.hi()[0] - output_domain.lo()[0] + 1; - effective_batch_size = output_domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input_domain.get_volume()); - } else { - in_dim = input_domain.hi()[0] - input_domain.lo()[0] + 1; - out_dim = output_domain.hi()[0] - output_domain.lo()[0] + 1; - effective_batch_size = output_domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input_domain.get_volume()); - } + ProfilingSettings profiling = acc.get_argument(PROFILING); + EmbeddingAttrs attrs = acc.get_argument(ATTRS); - forward_kernel_wrapper(m, - input_ptr, - output_ptr, - kernel_ptr, - in_dim, - out_dim, - effective_batch_size, - m->aggr, - output_domain.get_volume()); + return profile(backward_kernel, + profiling, + "[Embedding] forward_time = %.2lfms\n", + input, + output, + weight_grad, + input.data_type, + output.data_type, + attrs.aggr, + input.shape.get_dim(), + output.shape.get_dim(), + input.shape[ff_dim_t(0)]); } -#endif -void Embedding::backward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_backward(ff, argmap); - IndexLauncher launcher(EMBED_BWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - // regions[0]: input - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - // regions[1]: output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - 0 /*projection*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - // regions[2]: weight_grad - launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, - 0 /*projection*/, 
- READ_WRITE, - EXCLUSIVE, - weights[0]->region_grad)); - launcher.add_field(2, FID_DATA); - runtime->execute_index_space(ctx, launcher); +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -void Embedding::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - EmbeddingMeta const *m = *((EmbeddingMeta **)task->local_args); - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // Assert that weight and output must have the same data type - // otherwise, a cast operator should be inserted - assert(m->weight_type[0] == m->output_type[0]); - assert(m->input_type[0] == DT_INT32 || m->input_type[0] == DT_INT64); - GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( - m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( - m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); - GenericTensorAccessorW kernel_grad = helperGetGenericTensorAccessorRW( - m->weight_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - if (m->aggr == AGGR_MODE_NONE) { - // assert(kernel_grad_domain.get_dim() == 2); - assert(input.domain.get_dim() + 1 == output_grad.domain.get_dim()); - for (size_t i = 0; i < input.domain.get_dim(); i++) { - assert(input.domain.hi()[i] == output_grad.domain.hi()[i + 1]); - assert(input.domain.lo()[i] == output_grad.domain.lo()[i + 1]); - } - assert(kernel_grad.domain.hi()[0] - kernel_grad.domain.lo()[0] == - output_grad.domain.hi()[0] - output_grad.domain.lo()[0]); - } else { - // assert(kernel_grad_domain.get_dim() == 2); - assert(input.domain.get_dim() == output_grad.domain.get_dim()); - for (size_t i = 1; i < input.domain.get_dim(); i++) { - assert(input.domain.hi()[i] == output_grad.domain.hi()[i]); - assert(input.domain.lo()[i] == output_grad.domain.lo()[i]); - } - assert(kernel_grad.domain.hi()[0] - kernel_grad.domain.lo()[0] == - output_grad.domain.hi()[0] - output_grad.domain.lo()[0]); - } - int in_dim, out_dim, effective_batch_size; - if (m->aggr == AGGR_MODE_NONE) { - in_dim = 1; - out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1; - effective_batch_size = output_grad.domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input.domain.get_volume()); - } else { - in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; - out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1; - effective_batch_size = output_grad.domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input.domain.get_volume()); - } - backward_kernel_wrapper(m, - input, - output_grad, - kernel_grad, - in_dim, - out_dim, - effective_batch_size); -} - -#ifdef DEADCODE -template -void Embedding::backward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // const Embedding* embed = (Embedding*) task->args; - EmbeddingMeta const *m = *((EmbeddingMeta **)task->local_args); - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Domain output_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Domain kernel_grad_domain = runtime->get_index_space_domain( - ctx, 
task->regions[2].region.get_index_space()); - if (m->aggr == AGGR_MODE_NONE) { - // assert(kernel_grad_domain.get_dim() == 2); - assert(input_domain.get_dim() + 1 == output_grad_domain.get_dim()); - for (size_t i = 0; i < input_domain.get_dim(); i++) { - assert(input_domain.hi()[i] == output_grad_domain.hi()[i + 1]); - assert(input_domain.lo()[i] == output_grad_domain.lo()[i + 1]); - } - assert(kernel_grad_domain.hi()[0] - kernel_grad_domain.lo()[0] == - output_grad_domain.hi()[0] - output_grad_domain.lo()[0]); - } else { - // assert(kernel_grad_domain.get_dim() == 2); - assert(input_domain.get_dim() == output_grad_domain.get_dim()); - for (size_t i = 1; i < input_domain.get_dim(); i++) { - assert(input_domain.hi()[i] == output_grad_domain.hi()[i]); - assert(input_domain.lo()[i] == output_grad_domain.lo()[i]); - } - assert(kernel_grad_domain.hi()[0] - kernel_grad_domain.lo()[0] == - output_grad_domain.hi()[0] - output_grad_domain.lo()[0]); - } - const TI *input_ptr = helperGetTensorPointerRO( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - float const *output_grad_ptr = helperGetTensorPointerWO( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - float *kernel_grad_ptr = helperGetTensorPointerRW( - regions[2], task->regions[2], FID_DATA, ctx, runtime); - - int in_dim, out_dim, effective_batch_size; - if (m->aggr == AGGR_MODE_NONE) { - in_dim = 1; - out_dim = output_grad_domain.hi()[0] - output_grad_domain.lo()[0] + 1; - effective_batch_size = output_grad_domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input_domain.get_volume()); - } else { - in_dim = input_domain.hi()[0] - input_domain.lo()[0] + 1; - out_dim = output_grad_domain.hi()[0] - output_grad_domain.lo()[0] + 1; - effective_batch_size = output_grad_domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input_domain.get_volume()); - } - backward_kernel_wrapper(m, - input_ptr, - output_grad_ptr, - kernel_grad_ptr, - in_dim, - out_dim, - effective_batch_size, - m->aggr, - output_grad_domain.get_volume()); -} -#endif - -bool Embedding::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_input, sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - - EmbeddingMeta *m = new EmbeddingMeta(sim->handler, this); - assert(m->profiling == false); - m->aggr = this->aggr; +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + EmbeddingAttrs const &attrs, + InputParallelTensorDesc const &input, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); - sim->free_all(); - bool out_of_memory = false; - Domain in_domain = sub_input.get_domain(); - void *input_ptr = sim->allocate(sub_input.get_volume(), inputs[0]->data_type); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - GenericTensorAccessorW input_acc(inputs[0]->data_type, in_domain, input_ptr); + ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); + TensorShape weight_shape = + get_weights_shape(attrs, get_piece_shape(input.shape)); - out_of_memory = out_of_memory || (input_ptr == NULL); - Domain out_domain = sub_output.get_domain(); - void *output_ptr = - sim->allocate(sub_output.get_volume(), outputs[0]->data_type); - out_of_memory = out_of_memory || (output_ptr == NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - 
GenericTensorAccessorW output_acc( - outputs[0]->data_type, out_domain, output_ptr); + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input.shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind(WEIGHT, weight_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(ATTRS, attrs); - Domain weight_domain; - weight_domain.dim = 2; - weight_domain.rect_data[0] = 0; - weight_domain.rect_data[1] = 0; - weight_domain.rect_data[2] = num_entries - 1; - weight_domain.rect_data[3] = out_channels - 1; + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - void *weight_ptr = sim->allocate(num_entries * out_channels, this->data_type); - cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset); - out_of_memory = out_of_memory || (weight_ptr == NULL); - GenericTensorAccessorR weight_acc(this->data_type, weight_domain, weight_ptr); - if (out_of_memory) { - cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - return true; - } + auto fwd_accessor = env.get_fwd_accessor(EMBED_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(EMBED_BWD_TASK_ID, bwd_binding); - int in_dim = this->aggr == AGGR_MODE_NONE ? 1 : sub_input.dims[0].size; - int out_dim = sub_output.dims[0].size; - int effective_batch_size = sub_output.get_volume() / out_dim; - assert(effective_batch_size * in_dim == sub_input.get_volume()); + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); - // Randomly initialize the intput tensor to avoid out of index range issues - if (inputs[0]->data_type == DT_INT32) { - rand_generate_int32_wrapper( - input_acc.get_int32_ptr(), sub_input.get_volume(), num_entries); - } else if (inputs[0]->data_type == DT_INT64) { - rand_generate_int64_wrapper( - input_acc.get_int64_ptr(), sub_input.get_volume(), num_entries); - } - - std::function forward, backward; - forward = [&] { - forward_kernel_wrapper(m, - input_acc, - output_acc, - weight_acc, - in_dim, - out_dim, - effective_batch_size); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - void *weight_grad_ptr = - sim->allocate(num_entries * out_channels, this->data_type); - cost_metrics.weights_memory += - cost_metrics.total_mem_diff_from(sim->offset); - out_of_memory = out_of_memory || (weight_grad_ptr == NULL); - GenericTensorAccessorW weight_grad_acc( - this->data_type, weight_domain, weight_grad_ptr); - - void *output_grad_ptr = - sim->allocate(sub_output.get_volume(), outputs[0]->data_type); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - out_of_memory = out_of_memory || (output_grad_ptr == NULL); - GenericTensorAccessorR output_grad_acc( - outputs[0]->data_type, out_domain, output_grad_ptr); - - void *input_grad_ptr = - sim->allocate(sub_input.get_volume(), inputs[0]->data_type); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - out_of_memory = out_of_memory || (input_grad_ptr == NULL); - GenericTensorAccessorW input_grad_acc( - inputs[0]->data_type, in_domain, input_grad_ptr); - - if (out_of_memory) { - cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - return true; - } - backward = [&] { - backward_kernel_wrapper(m, - input_grad_acc, - output_grad_acc, - weight_grad_acc, - in_dim, - out_dim, - effective_batch_size); - }; - } - - inner_measure_operator_cost(sim, forward, 
backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure Embedding] name(%s) forward_time(%.4lf) " - "backward_time(%.4lf)\n", - name, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure Embedding] name(%s) forward_time(%.4lf)\n", - name, - cost_metrics.forward_time); - } - delete m; - return true; + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -void EmbeddingLookup_int64_t_float_float__avx2_fma(int const block_size, - int const output_size, - int const index_size, - int const data_size, - float const *input, - int64_t const *indices, - int const *lengths, - float const *weight, - bool normalize_by_lengths, - float *out) { -#ifdef FF_USE_AVX2 - const int64_t prefdist_T0 = 16; - if (block_size == 128) { - // unrolling 16 times - int64_t dataInd = 0; - for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { - float *op = &out[rangeIndex * block_size]; - __m256 vop0 = _mm256_setzero_ps(); - __m256 vop8 = _mm256_setzero_ps(); - __m256 vop16 = _mm256_setzero_ps(); - __m256 vop24 = _mm256_setzero_ps(); - __m256 vop32 = _mm256_setzero_ps(); - __m256 vop40 = _mm256_setzero_ps(); - __m256 vop48 = _mm256_setzero_ps(); - __m256 vop56 = _mm256_setzero_ps(); - __m256 vop64 = _mm256_setzero_ps(); - __m256 vop72 = _mm256_setzero_ps(); - __m256 vop80 = _mm256_setzero_ps(); - __m256 vop88 = _mm256_setzero_ps(); - __m256 vop96 = _mm256_setzero_ps(); - __m256 vop104 = _mm256_setzero_ps(); - __m256 vop112 = _mm256_setzero_ps(); - __m256 vop120 = _mm256_setzero_ps(); - for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; - ++dataInd) { - const int64_t idx = indices[dataInd]; - float wgt = 1.f; - if (weight) { - wgt = weight[dataInd]; - } - __m256 vwgt = _mm256_set1_ps(wgt); - float const *ip = &input[idx * block_size]; - const int64_t next_T0 = (dataInd < index_size - prefdist_T0) - ? 
(dataInd + prefdist_T0) - : dataInd; - const int64_t idx_pref_T0 = indices[next_T0]; - assert(idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && - idx_pref_T0 < data_size); - float const *ip_next_T0 = &input[idx_pref_T0 * block_size]; - vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); - vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); - _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); - vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); - vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); - _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); - vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); - vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); - _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); - vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); - _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); - vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); - _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); - vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); - vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); - _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0); - vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); - _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); - vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); - _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0); - vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); - _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); - vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); - _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0); - vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); - _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); - vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); - _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0); - } - if (normalize_by_lengths == false) { - _mm256_storeu_ps(&op[0], vop0); - _mm256_storeu_ps(&op[8], vop8); - _mm256_storeu_ps(&op[16], vop16); - _mm256_storeu_ps(&op[24], vop24); - _mm256_storeu_ps(&op[32], vop32); - _mm256_storeu_ps(&op[40], vop40); - _mm256_storeu_ps(&op[48], vop48); - _mm256_storeu_ps(&op[56], vop56); - _mm256_storeu_ps(&op[64], vop64); - _mm256_storeu_ps(&op[72], vop72); - _mm256_storeu_ps(&op[80], vop80); - _mm256_storeu_ps(&op[88], vop88); - _mm256_storeu_ps(&op[96], vop96); - _mm256_storeu_ps(&op[104], vop104); - _mm256_storeu_ps(&op[112], vop112); - _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { - __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); - _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); - _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); - _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); - _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); - _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); - _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); - _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); - _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); - _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); - _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); - _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); - _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); - 
_mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); - _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); - _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); - _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); - } - } - __m256 vwgt = _mm256_set1_ps(wgt); - float const *ip = &input[idx * block_size]; - const int64_t next_T0 = (dataInd < index_size - prefdist_T0) - ? (dataInd + prefdist_T0) - : dataInd; - const int64_t idx_pref_T0 = indices[next_T0]; - assert(idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && - idx_pref_T0 < data_size); - float const *ip_next_T0 = &input[idx_pref_T0 * block_size]; - vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); - vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); - _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); - vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); - vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); - _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); - } - if (normalize_by_lengths == false) { - _mm256_storeu_ps(&op[0], vop0); - _mm256_storeu_ps(&op[8], vop8); - _mm256_storeu_ps(&op[16], vop16); - _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { - __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); - _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); - _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); - _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); - _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); - } -} -} -else { - // generic code - int64_t dataInd = 0; - for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { - float *op = &out[rangeIndex * block_size]; - int j = 0; - for (; j + 8 <= block_size; j += 8) { - _mm256_storeu_ps(op + j, _mm256_setzero_ps()); - } - for (; j < block_size; j++) { - op[j] = 0.0f; - } - for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; - ++dataInd) { - const int64_t idx = indices[dataInd]; - float wgt = 1.f; - if (weight) { - wgt = weight[dataInd]; - } - __m256 vwgt = _mm256_set1_ps(wgt); - float const *ip = &input[idx * block_size]; - const int64_t next_T0 = (dataInd < index_size - prefdist_T0) - ? 
(dataInd + prefdist_T0) - : dataInd; - const int64_t idx_pref_T0 = indices[next_T0]; - assert(idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && - idx_pref_T0 < data_size); - float const *ip_next_T0 = &input[idx_pref_T0 * block_size]; - j = 0; - for (; j + 8 <= block_size; j += 8) { - _mm256_storeu_ps(&op[j], - _mm256_fmadd_ps(vwgt, - _mm256_loadu_ps(&ip[j]), - _mm256_loadu_ps(&op[j]))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); - } - for (; j < block_size; j++) { - op[j] += wgt * ip[j]; - } - } - if (normalize_by_lengths && lengths[rangeIndex]) { - float len_inv = 1.0f / lengths[rangeIndex]; - __m256 vlen_inv = _mm256_set1_ps(len_inv); - j = 0; - for (; j + 8 <= block_size; j += 8) { - _mm256_storeu_ps(&op[j], - _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); - } - for (; j < block_size; j++) { - op[j] = len_inv * op[j]; - } - } - } -} -#else - assert(0); -#endif -} +template <> +OpTaskSignature fwd_signature() { + OpTaskSignature fwd(OpTaskType::FWD); -void embed_forward(int64_t const *input, - int const *lengths, - float *output, - float const *embed, - int block_size, - int output_size, - int index_size, - int data_size) { - EmbeddingLookup_int64_t_float_float__avx2_fma(block_size, - output_size, - index_size, - data_size, - embed, - input, - lengths, - nullptr, - false, - output); -} + fwd.add_input_slot(INPUT); + fwd.add_input_slot(OUTPUT); + fwd.add_input_slot(WEIGHT); -void embed_backward_generic(int64_t const *input, - int const *lengths, - float const *output, - float *embed, - int block_size, - int output_size, - int index_size, - int data_size) { - // FIXME: Not functionaly correct. - for (int i = 0; i < output_size * block_size; i++) { - int idx = i / block_size; - int off = i % block_size; - int64_t wordIdx = input[idx]; - // FIXME: Need to be atomic depending on the strategy - embed[wordIdx * block_size + off] += output[i]; - ; - } -} + fwd.add_arg_slot(ATTRS); + fwd.add_arg_slot(PROFILING); -void embed_backward(int64_t const *input, - int const *lengths, - float const *output, - float *embed, - int block_size, - int output_size, - int index_size, - int data_size) { - embed_backward_generic(input, - lengths, - output, - embed, - block_size, - output_size, - index_size, - data_size); + return fwd; } -void Embedding::forward_task_cpu(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // const Embedding* embed = (Embedding*) task->args; - AccessorRO const acc_input(regions[0], FID_DATA); - AccessorWO const acc_output(regions[1], FID_DATA); - AccessorRO const acc_weight(regions[2], FID_DATA); - Rect<2> rect_input = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Rect<2> rect_output = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Rect<2> rect_weight = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); - coord_t batch_size = rect_input.hi[1] - rect_input.lo[1] + 1; - // Input and output have same batch size - assert(batch_size == rect_output.hi[1] - rect_output.lo[1] + 1); - coord_t out_dim = rect_output.hi[0] - rect_output.lo[0] + 1; - // Weight and output have same out dim - assert(out_dim == rect_weight.hi[1] - rect_weight.lo[1] + 1); - // const int64_t* input = acc_input.ptr(rect_input); - // float* output = acc_output.ptr(rect_output); - // const float* weight = acc_weight.ptr(rect_weight); - int block_size = out_dim; - int output_size = batch_size; - int 
data_size = 1000000; // FIXME - // For now we are assuming the length is always 1 - int index_size = rect_input.hi[1] - rect_input.lo[1] + 1; - coord_t in_dim = rect_input.hi[0] - rect_input.lo[0] + 1; - assert(in_dim == 1); - std::vector lengths(output_size, 1); - embed_forward(acc_input.ptr(rect_input), - lengths.data(), - acc_output.ptr(rect_output), - acc_weight.ptr(rect_weight), - block_size, - output_size, - index_size, - data_size); +template <> +void register_task() { + register_task(EMBED_FWD_TASK_ID, + "Embed Fwd", + fwd_signature(), + forward_task); } -void Embedding::backward_task_cpu(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // const Embedding* embed = (Embedding*) task->args; - AccessorRO const acc_input(regions[0], FID_DATA); - AccessorRO const acc_output(regions[1], FID_DATA); - AccessorRW const acc_weight(regions[2], FID_DATA); - Rect<2> rect_input = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Rect<2> rect_output = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Rect<2> rect_weight = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); - coord_t batch_size = rect_input.hi[1] - rect_input.lo[1] + 1; - // Input and output have same batch size - assert(batch_size == rect_output.hi[1] - rect_output.lo[1] + 1); - // coord_t in_dim = rect_input.hi[0] - rect_input.lo[0] + 1; - coord_t out_dim = rect_output.hi[0] - rect_output.lo[0] + 1; - // Weight and output have same out dim - assert(out_dim == rect_weight.hi[1] - rect_weight.lo[1] + 1); - // const int64_t* input = acc_input.ptr(rect_input); - // const float* output = acc_output.ptr(rect_output); - // float* weight = acc_weight.ptr(rect_weight); - int block_size = out_dim; - int output_size = batch_size; - int index_size = rect_input.hi[1] - rect_input.lo[0] + 1; - int data_size = 1000000; // FIXME - std::vector lengths(output_size, 1); - embed_backward(acc_input.ptr(rect_input), - lengths.data(), - acc_output.ptr(rect_output), - acc_weight.ptr(rect_weight), - block_size, - output_size, - index_size, - data_size); +template <> +OpTaskSignature bwd_signature() { + OpTaskSignature bwd = infer_bwd_signature(fwd_signature()); + return bwd; } -EmbeddingMeta::EmbeddingMeta(FFHandler _handle, Op const *op) - : PerDeviceOpState(_handle, op) {} +template <> +void register_task() { + register_task(EMBED_BWD_TASK_ID, + "Embed Bwd", + bwd_signature(), + backward_task); } -; // namespace FlexFlow -namespace std { -size_t hash::operator()( - FlexFlow::EmbeddingParams const ¶ms) const { - size_t key = 0; - hash_combine(key, params.layer_guid.id); - hash_combine(key, params.out_channels); - hash_combine(key, params.aggr); - hash_combine(key, params.num_entries); - hash_combine(key, params.data_type); - return key; -} -}; // namespace std +} // namespace FlexFlow diff --git a/lib/runtime/src/ops/embedding.h b/lib/runtime/src/ops/embedding.h index 0496d93dd9..cd1b14fa66 100644 --- a/lib/runtime/src/ops/embedding.h +++ b/lib/runtime/src/ops/embedding.h @@ -2,124 +2,25 @@ #define _FLEXFLOW_EMBEDDING_H #include "op-attrs/ops/embedding.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { -template <> -void register_task(); template <> void register_task(); template <> void register_task(); -OpTaskInvocation init(EmbeddingAttrs const &); OpTaskInvocation 
forward(EmbeddingAttrs const &); OpTaskInvocation backward(EmbeddingAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, EmbeddingAttrs const &attrs, - ParallelTensorShape const &input_shape, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); -/* namespace Weight { */ -/* enum { */ -/* OUT_CHANNELS = 0, */ -/* VOCAB_SIZE = 1, */ -/* }; */ -/* }; */ - -/* namespace Output { */ -/* enum { OUT_CHANNELS = 0 }; */ -/* }; */ - -/* class Embedding; */ - -/* class Embedding : public Op { */ -/* public: */ -/* using Attrs = EmbeddingAttrs; */ - -/* Embedding(FFModel &model, */ -/* LayerID const &_layer_guid, */ -/* const ParallelTensor _input, */ -/* int _num_entries, */ -/* int _out_channels, */ -/* AggrMode _aggr, */ -/* bool allocate_weights, */ -/* DataType _dtype, */ -/* char const *name); */ -/* Embedding(FFModel &model, */ -/* Embedding const &other, */ -/* const ParallelTensor input, */ -/* bool allocate_weights); */ -/* Embedding(FFModel &model, */ -/* Attrs const ¶ms, */ -/* std::vector const &input, */ -/* bool allocate_weights = false, */ -/* char const *name = nullptr); */ -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* // void update(const FFModel&); */ -/* // Parameter* get_parameter(int index); */ -/* // void create_weights(FFModel& model); */ -/* // void create_input_partition(FFModel& model); */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void */ -/* forward_task_cpu(Legion::Task const *task, */ -/* std::vector const ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void */ -/* backward_task_cpu(Legion::Task const *task, */ -/* std::vector const ®ions, - */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ - -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* private: */ -/* int input_vocab_size_replica_dim() const; */ -/* int input_channel_out_replica_dim() const; */ -/* int output_vocab_size_replica_dim() const; */ - -/* int output_size(ParallelDim output_dims[MAX_TENSOR_DIM]); */ -/* int weight_size(ParallelDim weights_dims[MAX_TENSOR_DIM]); */ - -/* void register_mappings(); */ -/* void register_output_mappings(); */ -/* void register_weight_mappings(); */ - -/* public: */ -/* int num_entries, out_channels; */ -/* AggrMode aggr; */ -/* }; */ - } // namespace FlexFlow #endif
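
As a closing illustration of the embedding refactor: the removed FFModel::embedding code above constructed the output shape by hand, while the new path obtains it from get_output_shape(attrs, ...). The rule both rely on is sketched below as a standalone snippet with simplified types; only AggrMode and its AGGR_MODE_NONE value are taken from this patch, everything else is a stand-in.

#include <vector>

// Illustration only: embedding output dims, listed outermost..innermost.
std::vector<int> embedding_output_dims(std::vector<int> index_dims,
                                       int out_dim,
                                       AggrMode aggr) {
  if (aggr == AGGR_MODE_NONE) {
    // One lookup per index: rank grows by one and out_dim becomes the
    // innermost dimension, e.g. indices (batch=32, seq=16) -> (32, 16, out_dim).
    index_dims.push_back(out_dim);
  } else {
    // Sum/avg over a bag of indices: rank is preserved and the innermost
    // (bag) dimension is replaced, e.g. (batch=32, bag=16) -> (32, out_dim).
    index_dims.back() = out_dim;
  }
  return index_dims;
}

The weight tensor keeps its (out_channels, num_entries) layout in either mode, which is why the new measure_operator_cost only needs get_weights_shape(attrs, ...) applied to the input piece shape.
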