From 9f8ba469fd529f86d11f697a71d373435628b232 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Mon, 1 Jan 2024 11:46:16 -0800 Subject: [PATCH] Embedding --- .../include/kernels/embedding_kernels.h | 16 +- lib/kernels/src/cuda/embedding_kernels.cu | 44 +- lib/kernels/src/hip/embedding_kernels.cpp | 40 +- lib/runtime/src/ops/embedding.cc | 1238 ++--------------- lib/runtime/src/ops/embedding.h | 103 +- 5 files changed, 162 insertions(+), 1279 deletions(-) diff --git a/lib/kernels/include/kernels/embedding_kernels.h b/lib/kernels/include/kernels/embedding_kernels.h index 9d70fd9a79..34b892c17e 100644 --- a/lib/kernels/include/kernels/embedding_kernels.h +++ b/lib/kernels/include/kernels/embedding_kernels.h @@ -5,29 +5,25 @@ #include "kernels/device.h" namespace FlexFlow { - -class EmbeddingPerDeviceState : public PerDeviceOpState { -public: - EmbeddingPerDeviceState(FFHandler handle); - DataType input_data_type, output_data_type; - AggrMode aggr; -}; - namespace Kernels { namespace Embedding { void forward_kernel(ffStream_t stream, - EmbeddingPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorR const &weight, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size); void backward_kernel(ffStream_t stream, - EmbeddingPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorR const &output, GenericTensorAccessorW const &weight_grad, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size); diff --git a/lib/kernels/src/cuda/embedding_kernels.cu b/lib/kernels/src/cuda/embedding_kernels.cu index b97d74d010..9d3cca66a0 100644 --- a/lib/kernels/src/cuda/embedding_kernels.cu +++ b/lib/kernels/src/cuda/embedding_kernels.cu @@ -24,7 +24,7 @@ namespace Embedding { template struct ForwardKernel { void operator()(cudaStream_t stream, - EmbeddingPerDeviceState const *m, + AggrMode aggr, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorR const &weight, @@ -35,8 +35,8 @@ struct ForwardKernel { assert(weight.data_type == DT_HALF || weight.data_type == DT_FLOAT || weight.data_type == DT_DOUBLE); - if (m->aggr == AGGR_MODE_NONE) { - embed_forward_no_aggr<<<<>>(input.get(), @@ -45,8 +45,8 @@ struct ForwardKernel { out_dim, batch_size); } else { - assert(m->aggr == AGGR_MODE_AVG || m->aggr == AGGR_MODE_SUM); - embed_forward_with_aggr<<<<>>(input.get(), @@ -55,7 +55,7 @@ struct ForwardKernel { out_dim, in_dim, batch_size, - m->aggr); + aggr); } } } @@ -63,7 +63,7 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(cudaStream_t stream, - EmbeddingPerDeviceState const *m, + AggrMode aggr, GenericTensorAccessorR const &input, GenericTensorAccessorR const &output, GenericTensorAccessorW const &weight_grad, @@ -73,8 +73,8 @@ struct BackwardKernel { assert(input.data_type == DT_INT32 || input.data_type == DT_INT64); assert(output.data_type == DT_HALF || output.data_type == DT_FLOAT, || output.data_type == DT_DOUBLE); - if (m->aggr == AGGR_MODE_NONE) { - embed_backward_no_aggr<<<<>>(input.get(), @@ -83,7 +83,7 @@ struct BackwardKernel { out_dim, batch_size); } else { - embed_backward_with_aggr<<<<>>(input.get(), @@ -92,23 +92,25 @@ struct BackwardKernel { out_dim, in_dim, batch_size, - m->aggr); + aggr); } } } -void forward_kernel(cudaStream_t stream, - EmbeddingPerDeviceState const *m, +void forward_kernel(ffStream_t stream, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorR const &weight, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size) { - DataTypeDispatch2{}(m->input_data_type, - m->output_data_type, + DataTypeDispatch2{}(input_data_type, + output_data_type, stream, - m, + aggr, input, output, weight, @@ -118,17 +120,19 @@ void forward_kernel(cudaStream_t stream, } void backward_kernel(cudaStream_t stream, - EmbeddingPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorR const &output, GenericTensorAccessorW const &weight_grad, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size) { - DataTypeDispatch2{}(m->input_data_type, - m->output_data_type, + DataTypeDispatch2{}(input_data_type, + output_data_type, stream, - m, + aggr, input, output, weight, diff --git a/lib/kernels/src/hip/embedding_kernels.cpp b/lib/kernels/src/hip/embedding_kernels.cpp index 93bb7276cb..17edfea5c1 100644 --- a/lib/kernels/src/hip/embedding_kernels.cpp +++ b/lib/kernels/src/hip/embedding_kernels.cpp @@ -25,7 +25,7 @@ namespace Embedding { template struct ForwardKernel { void operator()(hipStream_t stream, - EmbeddingPerDeviceState const *m, + AggrMode aggr, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorR const &weight, @@ -36,9 +36,9 @@ struct ForwardKernel { assert(weight.data_type == DT_HALF || weight.data_type == DT_FLOAT || weight.data_type == DT_DOUBLE); - if (m->aggr == AGGR_MODE_NONE) { + if (aggr == AGGR_MODE_NONE) { hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_no_aggr), - GET_BLOCKS(output.domain.get_volume()), + GET_BLOCKS(output.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, @@ -49,7 +49,7 @@ struct ForwardKernel { batch_size); } else { hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_with_aggr), - GET_BLOCKS(output.domain.get_volume()), + GET_BLOCKS(output.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, @@ -59,7 +59,7 @@ struct ForwardKernel { out_dim, in_dim, batch_size, - m->aggr); + aggr); } } } @@ -67,7 +67,7 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(hipStream_t stream, - EmbeddingPerDeviceState const *m, + AggrMode aggr, GenericTensorAccessorR const &input, GenericTensorAccessorR const &output, GenericTensorAccessorW const &weight_grad, @@ -77,9 +77,9 @@ struct BackwardKernel { assert(input.data_type == DT_INT32 || input.data_type == DT_INT64); assert(output.data_type == DT_HALF || output.data_type == DT_FLOAT, || output.data_type == DT_DOUBLE); - if (m->aggr == AGGR_MODE_NONE) { + if (aggr == AGGR_MODE_NONE) { hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_no_aggr), - GET_BLOCKS(output.domain.get_volume()), + GET_BLOCKS(output.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, @@ -90,7 +90,7 @@ struct BackwardKernel { batch_size); } else { hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_with_aggr), - GET_BLOCKS(output.domain.get_volume()), + GET_BLOCKS(output.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, @@ -100,23 +100,25 @@ struct BackwardKernel { out_dim, in_dim, batch_size, - m->aggr); + aggr); } } } void forward_kernel(hipStream_t stream, - EmbeddingPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorR const &weight, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size) { - DataTypeDispatch2{}(m->input_data_type, - m->output_data_type, + DataTypeDispatch2{}(input_data_type, + output_data_type, stream, - m, + aggr, input, output, weight, @@ -126,17 +128,19 @@ void forward_kernel(hipStream_t stream, } void backward_kernel(hipStream_t stream, - EmbeddingPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorR const &output, GenericTensorAccessorW const &weight_grad, + DataType input_data_type, + DataType output_data_type, + AggrMode aggr, int in_dim, int out_dim, int batch_size) { - DataTypeDispatch2{}(m->input_data_type, - m->output_data_type, + DataTypeDispatch2{}(input_data_type, + output_data_type, stream, - m, + aggr, input, output, weight, diff --git a/lib/runtime/src/ops/embedding.cc b/lib/runtime/src/ops/embedding.cc index 281ad9bc26..a1bc915d2f 100644 --- a/lib/runtime/src/ops/embedding.cc +++ b/lib/runtime/src/ops/embedding.cc @@ -14,1187 +14,165 @@ */ #include "embedding.h" -#include "utils/hash_utils.h" +#include "kernels/embedding_kernels.h" +#include "legion.h" +#include "op-attrs/ops/embedding.h" namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::InlineLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Embedding; -Tensor FFModel::embedding(const Tensor input, - int num_entries, - int out_dim, - AggrMode aggr, - DataType dtype, - Layer const *shared_op, - Initializer *kernel_initializer, - char const *name) { - Layer *embed = new Layer(this, - OP_EMBEDDING, - dtype, - name, - 1 /*inputs*/, - 1 /*weights*/, - 1 /*outputs*/, - input); - if (aggr == AGGR_MODE_NONE) { - int numdims = input->num_dims + 1; - int dims[MAX_TENSOR_DIM]; - for (int i = 1; i < numdims; i++) { - dims[i] = input->dims[i - 1]; - } - dims[0] = out_dim; - embed->outputs[0] = create_tensor_legion_ordering( - numdims, dims, embed->data_type, embed, 0, true /*create_grad*/); - } else { - int numdims = input->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = input->dims[i]; - } - dims[0] = out_dim; - embed->outputs[0] = create_tensor_legion_ordering( - numdims, dims, embed->data_type, embed, 0, true /*create_grad*/); - } - { - int dims[2] = {out_dim, num_entries}; - embed->weights[0] = create_weight_legion_ordering(2, - dims, - dtype, - embed, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - } - embed->data_type = dtype; - embed->add_int_property("num_entries", num_entries); - embed->add_int_property("out_dim", out_dim); - embed->add_int_property("aggr_mode", aggr); - embed->add_initializer("kernel", kernel_initializer); - layers.push_back(embed); - return embed->outputs[0]; -} - -EmbeddingParams Embedding::get_params() const { - EmbeddingParams params; - params.num_entries = this->num_entries; - params.out_channels = this->out_channels; - params.aggr = this->aggr; - params.data_type = this->data_type; - // TODO: get rid of layer_guid - // https://github.com/flexflow/FlexFlow/issues/304 - params.layer_guid = this->layer_guid; - return params; -} - -Op *Embedding::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("num_entries", value); - int num_entries = value; - layer->get_int_property("out_dim", value); - int out_dim = value; - layer->get_int_property("aggr_mode", value); - AggrMode aggr = (AggrMode)value; - Initializer *kernel_initializer; - layer->get_initializer("kernel", kernel_initializer); - return new Embedding(model, - layer->layer_guid, - inputs[0], - num_entries, - out_dim, - aggr, - false /*allocate_weights*/, - layer->data_type, - layer->name); -} - -int Embedding::input_vocab_size_replica_dim() const { - return this->inputs[0]->num_dims - 1; -} - -int Embedding::input_channel_out_replica_dim() const { - return this->inputs[0]->num_dims - 2; -} - -int Embedding::output_vocab_size_replica_dim() const { - assert(this->outputs[0] != nullptr); - return this->outputs[0]->num_dims - 1; -} - -int Embedding::output_size(ParallelDim output_dims[MAX_TENSOR_DIM]) { - ParallelTensor const &input = this->inputs[0]; - - int const OUT_CHANNELS = Output::OUT_CHANNELS; - if (aggr == AGGR_MODE_NONE) { - int num_dims = input->num_dims + 1; - for (int i = 1; i < num_dims - 1; i++) { - output_dims[i] = input->dims[i - 1]; - } - assert(OUT_CHANNELS == 0); - output_dims[OUT_CHANNELS].size = this->out_channels; - output_dims[OUT_CHANNELS].degree = 1; - output_dims[OUT_CHANNELS].parallel_idx = -1; - // Currently do not support parallelizing over the replica dim - output_dims[num_dims - 1].size = 1; - output_dims[num_dims - 1].degree = 1; - output_dims[num_dims - 1].parallel_idx = -1; - output_dims[num_dims - 1].is_replica_dim = true; - return num_dims; - } else { - int num_dims = input->num_dims; - for (int i = 1; i < num_dims - 1; i++) { - output_dims[i] = input->dims[i]; - } - assert(OUT_CHANNELS == 0); - output_dims[OUT_CHANNELS].size = this->out_channels; - output_dims[OUT_CHANNELS].degree = 1; - output_dims[OUT_CHANNELS].parallel_idx = -1; - // Currently do not support parallelizing over the replica dim - output_dims[num_dims - 1].size = 1; - output_dims[num_dims - 1].degree = 1; - output_dims[num_dims - 1].parallel_idx = -1; - output_dims[num_dims - 1].is_replica_dim = true; - return num_dims; - } - // const int REPLICA = this->output_vocab_size_replica_dim(); -} +enum Slots { INPUT, WEIGHT, OUTPUT, ATTRS, PROFILING }; -int Embedding::weight_size(ParallelDim weight_dims[MAX_TENSOR_DIM]) { - ParallelTensor const &input = this->inputs[0]; +OpTaskInvocation forward(EmbeddingAttrs const &attrs) { + OpTaskBinding b; - weight_dims[Weight::OUT_CHANNELS].size = this->out_channels; - weight_dims[Weight::OUT_CHANNELS].degree = 1; - weight_dims[Weight::OUT_CHANNELS].parallel_idx = -1; - weight_dims[Weight::VOCAB_SIZE].size = this->num_entries; - weight_dims[Weight::VOCAB_SIZE].degree = 1; - weight_dims[Weight::VOCAB_SIZE].parallel_idx = -1; - for (int i = 2; i < input->num_dims; i++) { - weight_dims[i].size = input->dims[i - 1].degree; - weight_dims[i].degree = weight_dims[i].size; - weight_dims[i].parallel_idx = input->dims[i - 1].parallel_idx; - weight_dims[i].is_replica_dim = true; - } - return input->num_dims; -} + b.bind(INPUT, input_tensor(0)); + b.bind(WEIGHT, weight_tensor(0)); + b.bind(OUTPUT, output_tensor(0)); -void Embedding::register_output_mappings() { - if (aggr == AGGR_MODE_NONE) { - int num_dims = this->inputs[0]->num_dims + 1; - for (int i = 1; i < num_dims - 1; i++) { - this->register_output_parallel_dims(i - 1, i); - } - } else { - int num_dims = this->inputs[0]->num_dims; - for (int i = 1; i < num_dims - 1; i++) { - this->register_output_parallel_dims(i, i); - } - } -} + b.bind_arg(ATTRS, attrs); + b.bind_arg(PROFILING, profiling_settings()); -void Embedding::register_weight_mappings() { - for (int i = 2; i < this->inputs[0]->num_dims; i++) { - this->register_weight_parallel_dims(i - 1, i); - } + return {EMBED_FWD_TASK_ID, b}; } -void Embedding::register_mappings() { - this->register_output_mappings(); - this->register_weight_mappings(); -} +OpTaskInvocation backward(EmbeddingAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); -/* Params */ - -bool operator==(EmbeddingParams const &lhs, EmbeddingParams const &rhs) { - return lhs.layer_guid == rhs.layer_guid && - lhs.out_channels == rhs.out_channels && - lhs.num_entries == rhs.num_entries && lhs.aggr == rhs.aggr && - lhs.data_type == rhs.data_type; + return {EMBED_BWD_TASK_ID, b}; } -Embedding::Embedding(FFModel &model, - EmbeddingParams const ¶ms, - ParallelTensor const input, - bool allocate_weights, - char const *name) - : Embedding(model, - params.layer_guid, - input, - params.num_entries, - params.out_channels, - params.aggr, - allocate_weights, - params.data_type, - name) {} - -Embedding::Embedding(FFModel &model, - Embedding const &other, - const ParallelTensor input, - bool allocate_weights) - : Embedding(model, - other.layer_guid, - input, - other.num_entries, - other.out_channels, - other.aggr, - allocate_weights, - other.data_type, - other.name) {} - -Embedding::Embedding(FFModel &model, - LayerID const &_layer_guid, - const ParallelTensor _input, - int _num_entries, - int _out_channels, - AggrMode _aggr, - bool allocate_weights, - DataType dtype, - char const *name) - : Op(model, - OP_EMBEDDING, - dtype, - name, - 1 /*inputs*/, - 1 /*weights*/, - allocate_weights, - 1 /*outputs*/, - _input), - num_entries(_num_entries), out_channels(_out_channels), aggr(_aggr) { - layer_guid = _layer_guid; - std::vector weight_dim_sets; - - int weight_ndim; - ParallelDim weight_dims[MAX_TENSOR_DIM]; - if (allocate_weights) { - weight_ndim = this->weight_size(weight_dims); - weight_dim_sets.push_back(weight_dims); - } - - ParallelDim output_dims[MAX_TENSOR_DIM]; - int output_ndim = this->output_size(output_dims); - - // register mappings between inputs/weights and outputs - this->register_mappings(); - - this->solve_parallel_dim_mappings( - {_input->dims}, weight_dim_sets, {output_dims}); - - if (allocate_weights) { - Initializer *weight_initializer = new GlorotUniform(std::rand() /*seed*/); - // Initializer *weight_initializer = new ZeroInitializer(/*seed*/); - - weights[0] = - model.create_parallel_weight_legion_ordering(weight_ndim, - weight_dims, - dtype, - nullptr /*owner_op*/, - true /*create_grad*/, - weight_initializer, - CHOSEN_SYNC_TYPE); - } +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(INPUT); + auto weight = acc.get_tensor(WEIGHT); + auto output = acc.get_tensor(OUTPUT); - outputs[0] = model.create_parallel_tensor_legion_ordering( - output_ndim, output_dims, dtype, this); + ProfilingSettings profiling = acc.get_argument(PROFILING); + EmbeddingAttrs attrs = acc.get_argument(ATTRS); - assert(check_output_input_weight_parallel_dims(allocate_weights)); + return profile(forward_kernel, + profiling, + "[Embedding] forward_time = %.2lfms\n", + input, + output, + weight, + input.data_type, + output.data_type, + attrs.aggr, + input.shape.get_dim(), + output.shape.get_dim(), + input.shape[legion_dim_t(1)]); } -void Embedding::init(FFModel const &ff) { - assert(check_output_input_weight_same_parallel_is()); - parallel_is = outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_init(ff, argmap); - IndexLauncher launcher(EMBED_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(Embedding)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - // regions[0]: input - // launcher.add_region_requirement( - // RegionRequirement(input_lps[0], 0/*projection*/, - // READ_ONLY, EXCLUSIVE, inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // regions[1]: output - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(0, FID_DATA); - // regions[2]: weight - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(1, FID_DATA); - // regions[3]: input_grad - launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - 0 /*projection*/, - WRITE_ONLY, - EXCLUSIVE, - inputs[0]->region_grad)); - launcher.add_field(2, FID_DATA); - FutureMap fm = runtime->execute_index_space(ctx, launcher); - fm.wait_all_results(); - set_opmeta_from_futuremap(ff, fm); -} - -PerDeviceOpState * - Embedding::init_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - Embedding const *embed = (Embedding *)task->args; - FFHandler handle = *((FFHandler const *)task->local_args); - EmbeddingMeta *m = new EmbeddingMeta(handle, embed); - m->profiling = embed->profiling; - m->aggr = embed->aggr; - return m; -} - -void Embedding::forward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_forward(ff, argmap); - IndexLauncher launcher(EMBED_FWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - // regions[0]: input - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - // regions[1]: output - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region, - MAP_TO_ZC_MEMORY)); - launcher.add_field(1, FID_DATA); - // regions[2]: weight - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(2, FID_DATA); - runtime->execute_index_space(ctx, launcher); -} - -/* - regions[0](I): input - regions[1](O): output - regions[2](I): kernel -*/ -void Embedding::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - EmbeddingMeta const *m = *((EmbeddingMeta **)task->local_args); - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // Assert that weight and output must have the same data type - // otherwise, a cast operator should be inserted - assert(m->weight_type[0] == m->output_type[0]); - assert(m->input_type[0] == DT_INT32 || m->input_type[0] == DT_INT64); - GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( - m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); - GenericTensorAccessorR kernel = helperGetGenericTensorAccessorRO( - m->weight_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - if (m->aggr == AGGR_MODE_NONE) { - // assert(kernel_domain.get_dim() == 2); - assert(input.domain.get_dim() + 1 == output.domain.get_dim()); - for (size_t i = 0; i < input.domain.get_dim(); i++) { - assert(input.domain.hi()[i] == output.domain.hi()[i + 1]); - assert(input.domain.lo()[i] == output.domain.lo()[i + 1]); - } - assert(kernel.domain.hi()[0] - kernel.domain.lo()[0] == - output.domain.hi()[0] - output.domain.lo()[0]); - } else { - // assert(kernel_domain.get_dim() == 2); - assert(input.domain.get_dim() == output.domain.get_dim()); - for (size_t i = 1; i < input.domain.get_dim(); i++) { - assert(input.domain.hi()[i] == output.domain.hi()[i]); - assert(input.domain.lo()[i] == output.domain.lo()[i]); - } - assert(kernel.domain.hi()[0] - kernel.domain.lo()[0] == - output.domain.hi()[0] - output.domain.lo()[0]); - } - - int in_dim, out_dim, effective_batch_size; - if (m->aggr == AGGR_MODE_NONE) { - in_dim = 1; - out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1; - effective_batch_size = output.domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input.domain.get_volume()); - } else { - in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; - out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1; - effective_batch_size = output.domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input.domain.get_volume()); - } - forward_kernel_wrapper( - m, input, output, kernel, in_dim, out_dim, effective_batch_size); + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -#ifdef DEADCODE -template -void Embedding::forward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // const Embedding* embed = (Embedding*) task->args; - EmbeddingMeta const *m = *((EmbeddingMeta **)task->local_args); - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Domain kernel_domain = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); - if (m->aggr == AGGR_MODE_NONE) { - // assert(kernel_domain.get_dim() == 2); - assert(input_domain.get_dim() + 1 == output_domain.get_dim()); - for (size_t i = 0; i < input_domain.get_dim(); i++) { - assert(input_domain.hi()[i] == output_domain.hi()[i + 1]); - assert(input_domain.lo()[i] == output_domain.lo()[i + 1]); - } - assert(kernel_domain.hi()[0] - kernel_domain.lo()[0] == - output_domain.hi()[0] - output_domain.lo()[0]); - } else { - // assert(kernel_domain.get_dim() == 2); - assert(input_domain.get_dim() == output_domain.get_dim()); - for (size_t i = 1; i < input_domain.get_dim(); i++) { - assert(input_domain.hi()[i] == output_domain.hi()[i]); - assert(input_domain.lo()[i] == output_domain.lo()[i]); - } - assert(kernel_domain.hi()[0] - kernel_domain.lo()[0] == - output_domain.hi()[0] - output_domain.lo()[0]); - } - const TI *input_ptr = helperGetTensorPointerRO( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - float *output_ptr = helperGetTensorPointerWO( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - float const *kernel_ptr = helperGetTensorPointerRO( - regions[2], task->regions[2], FID_DATA, ctx, runtime); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto weight_grad = acc.get_tensor_grad(WEIGHT); - int in_dim, out_dim, effective_batch_size; - if (m->aggr == AGGR_MODE_NONE) { - in_dim = 1; - out_dim = output_domain.hi()[0] - output_domain.lo()[0] + 1; - effective_batch_size = output_domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input_domain.get_volume()); - } else { - in_dim = input_domain.hi()[0] - input_domain.lo()[0] + 1; - out_dim = output_domain.hi()[0] - output_domain.lo()[0] + 1; - effective_batch_size = output_domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input_domain.get_volume()); - } + ProfilingSettings profiling = acc.get_argument(PROFILING); + EmbeddingAttrs attrs = acc.get_argument(ATTRS); - forward_kernel_wrapper(m, - input_ptr, - output_ptr, - kernel_ptr, - in_dim, - out_dim, - effective_batch_size, - m->aggr, - output_domain.get_volume()); + return profile(backward_kernel, + profiling, + "[Embedding] forward_time = %.2lfms\n", + input, + output, + weight_grad, + input.data_type, + output.data_type, + attrs.aggr, + input.shape.get_dim(), + output.shape.get_dim(), + input.shape[ff_dim_t(0)]); } -#endif -void Embedding::backward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_backward(ff, argmap); - IndexLauncher launcher(EMBED_BWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - // regions[0]: input - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - // regions[1]: output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - 0 /*projection*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - // regions[2]: weight_grad - launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, - 0 /*projection*/, - READ_WRITE, - EXCLUSIVE, - weights[0]->region_grad)); - launcher.add_field(2, FID_DATA); - runtime->execute_index_space(ctx, launcher); +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -void Embedding::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - EmbeddingMeta const *m = *((EmbeddingMeta **)task->local_args); - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // Assert that weight and output must have the same data type - // otherwise, a cast operator should be inserted - assert(m->weight_type[0] == m->output_type[0]); - assert(m->input_type[0] == DT_INT32 || m->input_type[0] == DT_INT64); - GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( - m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( - m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); - GenericTensorAccessorW kernel_grad = helperGetGenericTensorAccessorRW( - m->weight_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - if (m->aggr == AGGR_MODE_NONE) { - // assert(kernel_grad_domain.get_dim() == 2); - assert(input.domain.get_dim() + 1 == output_grad.domain.get_dim()); - for (size_t i = 0; i < input.domain.get_dim(); i++) { - assert(input.domain.hi()[i] == output_grad.domain.hi()[i + 1]); - assert(input.domain.lo()[i] == output_grad.domain.lo()[i + 1]); - } - assert(kernel_grad.domain.hi()[0] - kernel_grad.domain.lo()[0] == - output_grad.domain.hi()[0] - output_grad.domain.lo()[0]); - } else { - // assert(kernel_grad_domain.get_dim() == 2); - assert(input.domain.get_dim() == output_grad.domain.get_dim()); - for (size_t i = 1; i < input.domain.get_dim(); i++) { - assert(input.domain.hi()[i] == output_grad.domain.hi()[i]); - assert(input.domain.lo()[i] == output_grad.domain.lo()[i]); - } - assert(kernel_grad.domain.hi()[0] - kernel_grad.domain.lo()[0] == - output_grad.domain.hi()[0] - output_grad.domain.lo()[0]); - } - int in_dim, out_dim, effective_batch_size; - if (m->aggr == AGGR_MODE_NONE) { - in_dim = 1; - out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1; - effective_batch_size = output_grad.domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input.domain.get_volume()); - } else { - in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; - out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1; - effective_batch_size = output_grad.domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input.domain.get_volume()); - } - backward_kernel_wrapper(m, - input, - output_grad, - kernel_grad, - in_dim, - out_dim, - effective_batch_size); -} - -#ifdef DEADCODE -template -void Embedding::backward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // const Embedding* embed = (Embedding*) task->args; - EmbeddingMeta const *m = *((EmbeddingMeta **)task->local_args); - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Domain output_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Domain kernel_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); - if (m->aggr == AGGR_MODE_NONE) { - // assert(kernel_grad_domain.get_dim() == 2); - assert(input_domain.get_dim() + 1 == output_grad_domain.get_dim()); - for (size_t i = 0; i < input_domain.get_dim(); i++) { - assert(input_domain.hi()[i] == output_grad_domain.hi()[i + 1]); - assert(input_domain.lo()[i] == output_grad_domain.lo()[i + 1]); - } - assert(kernel_grad_domain.hi()[0] - kernel_grad_domain.lo()[0] == - output_grad_domain.hi()[0] - output_grad_domain.lo()[0]); - } else { - // assert(kernel_grad_domain.get_dim() == 2); - assert(input_domain.get_dim() == output_grad_domain.get_dim()); - for (size_t i = 1; i < input_domain.get_dim(); i++) { - assert(input_domain.hi()[i] == output_grad_domain.hi()[i]); - assert(input_domain.lo()[i] == output_grad_domain.lo()[i]); - } - assert(kernel_grad_domain.hi()[0] - kernel_grad_domain.lo()[0] == - output_grad_domain.hi()[0] - output_grad_domain.lo()[0]); - } - const TI *input_ptr = helperGetTensorPointerRO( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - float const *output_grad_ptr = helperGetTensorPointerWO( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - float *kernel_grad_ptr = helperGetTensorPointerRW( - regions[2], task->regions[2], FID_DATA, ctx, runtime); - - int in_dim, out_dim, effective_batch_size; - if (m->aggr == AGGR_MODE_NONE) { - in_dim = 1; - out_dim = output_grad_domain.hi()[0] - output_grad_domain.lo()[0] + 1; - effective_batch_size = output_grad_domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input_domain.get_volume()); - } else { - in_dim = input_domain.hi()[0] - input_domain.lo()[0] + 1; - out_dim = output_grad_domain.hi()[0] - output_grad_domain.lo()[0] + 1; - effective_batch_size = output_grad_domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input_domain.get_volume()); - } - backward_kernel_wrapper(m, - input_ptr, - output_grad_ptr, - kernel_grad_ptr, - in_dim, - out_dim, - effective_batch_size, - m->aggr, - output_grad_domain.get_volume()); -} -#endif - -bool Embedding::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_input, sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - - EmbeddingMeta *m = new EmbeddingMeta(sim->handler, this); - assert(m->profiling == false); - m->aggr = this->aggr; +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + EmbeddingAttrs const &attrs, + InputParallelTensorDesc const &input, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); - sim->free_all(); - bool out_of_memory = false; - Domain in_domain = sub_input.get_domain(); - void *input_ptr = sim->allocate(sub_input.get_volume(), inputs[0]->data_type); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - GenericTensorAccessorW input_acc(inputs[0]->data_type, in_domain, input_ptr); + ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); + TensorShape weight_shape = + get_weights_shape(attrs, get_piece_shape(input.shape)); - out_of_memory = out_of_memory || (input_ptr == NULL); - Domain out_domain = sub_output.get_domain(); - void *output_ptr = - sim->allocate(sub_output.get_volume(), outputs[0]->data_type); - out_of_memory = out_of_memory || (output_ptr == NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - GenericTensorAccessorW output_acc( - outputs[0]->data_type, out_domain, output_ptr); + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input.shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind(WEIGHT, weight_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(ATTRS, attrs); - Domain weight_domain; - weight_domain.dim = 2; - weight_domain.rect_data[0] = 0; - weight_domain.rect_data[1] = 0; - weight_domain.rect_data[2] = num_entries - 1; - weight_domain.rect_data[3] = out_channels - 1; + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - void *weight_ptr = sim->allocate(num_entries * out_channels, this->data_type); - cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset); - out_of_memory = out_of_memory || (weight_ptr == NULL); - GenericTensorAccessorR weight_acc(this->data_type, weight_domain, weight_ptr); - if (out_of_memory) { - cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - return true; - } + auto fwd_accessor = env.get_fwd_accessor(EMBED_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(EMBED_BWD_TASK_ID, bwd_binding); - int in_dim = this->aggr == AGGR_MODE_NONE ? 1 : sub_input.dims[0].size; - int out_dim = sub_output.dims[0].size; - int effective_batch_size = sub_output.get_volume() / out_dim; - assert(effective_batch_size * in_dim == sub_input.get_volume()); + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); - // Randomly initialize the intput tensor to avoid out of index range issues - if (inputs[0]->data_type == DT_INT32) { - rand_generate_int32_wrapper( - input_acc.get_int32_ptr(), sub_input.get_volume(), num_entries); - } else if (inputs[0]->data_type == DT_INT64) { - rand_generate_int64_wrapper( - input_acc.get_int64_ptr(), sub_input.get_volume(), num_entries); - } - - std::function forward, backward; - forward = [&] { - forward_kernel_wrapper(m, - input_acc, - output_acc, - weight_acc, - in_dim, - out_dim, - effective_batch_size); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - void *weight_grad_ptr = - sim->allocate(num_entries * out_channels, this->data_type); - cost_metrics.weights_memory += - cost_metrics.total_mem_diff_from(sim->offset); - out_of_memory = out_of_memory || (weight_grad_ptr == NULL); - GenericTensorAccessorW weight_grad_acc( - this->data_type, weight_domain, weight_grad_ptr); - - void *output_grad_ptr = - sim->allocate(sub_output.get_volume(), outputs[0]->data_type); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - out_of_memory = out_of_memory || (output_grad_ptr == NULL); - GenericTensorAccessorR output_grad_acc( - outputs[0]->data_type, out_domain, output_grad_ptr); - - void *input_grad_ptr = - sim->allocate(sub_input.get_volume(), inputs[0]->data_type); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - out_of_memory = out_of_memory || (input_grad_ptr == NULL); - GenericTensorAccessorW input_grad_acc( - inputs[0]->data_type, in_domain, input_grad_ptr); - - if (out_of_memory) { - cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - return true; - } - backward = [&] { - backward_kernel_wrapper(m, - input_grad_acc, - output_grad_acc, - weight_grad_acc, - in_dim, - out_dim, - effective_batch_size); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure Embedding] name(%s) forward_time(%.4lf) " - "backward_time(%.4lf)\n", - name, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure Embedding] name(%s) forward_time(%.4lf)\n", - name, - cost_metrics.forward_time); - } - delete m; - return true; + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -void EmbeddingLookup_int64_t_float_float__avx2_fma(int const block_size, - int const output_size, - int const index_size, - int const data_size, - float const *input, - int64_t const *indices, - int const *lengths, - float const *weight, - bool normalize_by_lengths, - float *out) { -#ifdef FF_USE_AVX2 - const int64_t prefdist_T0 = 16; - if (block_size == 128) { - // unrolling 16 times - int64_t dataInd = 0; - for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { - float *op = &out[rangeIndex * block_size]; - __m256 vop0 = _mm256_setzero_ps(); - __m256 vop8 = _mm256_setzero_ps(); - __m256 vop16 = _mm256_setzero_ps(); - __m256 vop24 = _mm256_setzero_ps(); - __m256 vop32 = _mm256_setzero_ps(); - __m256 vop40 = _mm256_setzero_ps(); - __m256 vop48 = _mm256_setzero_ps(); - __m256 vop56 = _mm256_setzero_ps(); - __m256 vop64 = _mm256_setzero_ps(); - __m256 vop72 = _mm256_setzero_ps(); - __m256 vop80 = _mm256_setzero_ps(); - __m256 vop88 = _mm256_setzero_ps(); - __m256 vop96 = _mm256_setzero_ps(); - __m256 vop104 = _mm256_setzero_ps(); - __m256 vop112 = _mm256_setzero_ps(); - __m256 vop120 = _mm256_setzero_ps(); - for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; - ++dataInd) { - const int64_t idx = indices[dataInd]; - float wgt = 1.f; - if (weight) { - wgt = weight[dataInd]; - } - __m256 vwgt = _mm256_set1_ps(wgt); - float const *ip = &input[idx * block_size]; - const int64_t next_T0 = (dataInd < index_size - prefdist_T0) - ? (dataInd + prefdist_T0) - : dataInd; - const int64_t idx_pref_T0 = indices[next_T0]; - assert(idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && - idx_pref_T0 < data_size); - float const *ip_next_T0 = &input[idx_pref_T0 * block_size]; - vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); - vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); - _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); - vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); - vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); - _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); - vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); - vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); - _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); - vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); - _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); - vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); - _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); - vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); - vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); - _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0); - vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); - _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); - vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); - _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0); - vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); - _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); - vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); - _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0); - vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); - _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); - vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); - _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0); - } - if (normalize_by_lengths == false) { - _mm256_storeu_ps(&op[0], vop0); - _mm256_storeu_ps(&op[8], vop8); - _mm256_storeu_ps(&op[16], vop16); - _mm256_storeu_ps(&op[24], vop24); - _mm256_storeu_ps(&op[32], vop32); - _mm256_storeu_ps(&op[40], vop40); - _mm256_storeu_ps(&op[48], vop48); - _mm256_storeu_ps(&op[56], vop56); - _mm256_storeu_ps(&op[64], vop64); - _mm256_storeu_ps(&op[72], vop72); - _mm256_storeu_ps(&op[80], vop80); - _mm256_storeu_ps(&op[88], vop88); - _mm256_storeu_ps(&op[96], vop96); - _mm256_storeu_ps(&op[104], vop104); - _mm256_storeu_ps(&op[112], vop112); - _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { - __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); - _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); - _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); - _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); - _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); - _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); - _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); - _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); - _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); - _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); - _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); - _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); - _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); - _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); - _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); - _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); - _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); - } - } - __m256 vwgt = _mm256_set1_ps(wgt); - float const *ip = &input[idx * block_size]; - const int64_t next_T0 = (dataInd < index_size - prefdist_T0) - ? (dataInd + prefdist_T0) - : dataInd; - const int64_t idx_pref_T0 = indices[next_T0]; - assert(idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && - idx_pref_T0 < data_size); - float const *ip_next_T0 = &input[idx_pref_T0 * block_size]; - vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); - vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); - _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); - vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); - vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); - _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); - } - if (normalize_by_lengths == false) { - _mm256_storeu_ps(&op[0], vop0); - _mm256_storeu_ps(&op[8], vop8); - _mm256_storeu_ps(&op[16], vop16); - _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { - __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); - _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); - _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); - _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); - _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); - } -} -} -else { - // generic code - int64_t dataInd = 0; - for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { - float *op = &out[rangeIndex * block_size]; - int j = 0; - for (; j + 8 <= block_size; j += 8) { - _mm256_storeu_ps(op + j, _mm256_setzero_ps()); - } - for (; j < block_size; j++) { - op[j] = 0.0f; - } - for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; - ++dataInd) { - const int64_t idx = indices[dataInd]; - float wgt = 1.f; - if (weight) { - wgt = weight[dataInd]; - } - __m256 vwgt = _mm256_set1_ps(wgt); - float const *ip = &input[idx * block_size]; - const int64_t next_T0 = (dataInd < index_size - prefdist_T0) - ? (dataInd + prefdist_T0) - : dataInd; - const int64_t idx_pref_T0 = indices[next_T0]; - assert(idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && - idx_pref_T0 < data_size); - float const *ip_next_T0 = &input[idx_pref_T0 * block_size]; - j = 0; - for (; j + 8 <= block_size; j += 8) { - _mm256_storeu_ps(&op[j], - _mm256_fmadd_ps(vwgt, - _mm256_loadu_ps(&ip[j]), - _mm256_loadu_ps(&op[j]))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); - } - for (; j < block_size; j++) { - op[j] += wgt * ip[j]; - } - } - if (normalize_by_lengths && lengths[rangeIndex]) { - float len_inv = 1.0f / lengths[rangeIndex]; - __m256 vlen_inv = _mm256_set1_ps(len_inv); - j = 0; - for (; j + 8 <= block_size; j += 8) { - _mm256_storeu_ps(&op[j], - _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); - } - for (; j < block_size; j++) { - op[j] = len_inv * op[j]; - } - } - } -} -#else - assert(0); -#endif -} +template <> +OpTaskSignature fwd_signature() { + OpTaskSignature fwd(OpTaskType::FWD); -void embed_forward(int64_t const *input, - int const *lengths, - float *output, - float const *embed, - int block_size, - int output_size, - int index_size, - int data_size) { - EmbeddingLookup_int64_t_float_float__avx2_fma(block_size, - output_size, - index_size, - data_size, - embed, - input, - lengths, - nullptr, - false, - output); -} + fwd.add_input_slot(INPUT); + fwd.add_input_slot(OUTPUT); + fwd.add_input_slot(WEIGHT); -void embed_backward_generic(int64_t const *input, - int const *lengths, - float const *output, - float *embed, - int block_size, - int output_size, - int index_size, - int data_size) { - // FIXME: Not functionaly correct. - for (int i = 0; i < output_size * block_size; i++) { - int idx = i / block_size; - int off = i % block_size; - int64_t wordIdx = input[idx]; - // FIXME: Need to be atomic depending on the strategy - embed[wordIdx * block_size + off] += output[i]; - ; - } -} + fwd.add_arg_slot(ATTRS); + fwd.add_arg_slot(PROFILING); -void embed_backward(int64_t const *input, - int const *lengths, - float const *output, - float *embed, - int block_size, - int output_size, - int index_size, - int data_size) { - embed_backward_generic(input, - lengths, - output, - embed, - block_size, - output_size, - index_size, - data_size); + return fwd; } -void Embedding::forward_task_cpu(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // const Embedding* embed = (Embedding*) task->args; - AccessorRO const acc_input(regions[0], FID_DATA); - AccessorWO const acc_output(regions[1], FID_DATA); - AccessorRO const acc_weight(regions[2], FID_DATA); - Rect<2> rect_input = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Rect<2> rect_output = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Rect<2> rect_weight = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); - coord_t batch_size = rect_input.hi[1] - rect_input.lo[1] + 1; - // Input and output have same batch size - assert(batch_size == rect_output.hi[1] - rect_output.lo[1] + 1); - coord_t out_dim = rect_output.hi[0] - rect_output.lo[0] + 1; - // Weight and output have same out dim - assert(out_dim == rect_weight.hi[1] - rect_weight.lo[1] + 1); - // const int64_t* input = acc_input.ptr(rect_input); - // float* output = acc_output.ptr(rect_output); - // const float* weight = acc_weight.ptr(rect_weight); - int block_size = out_dim; - int output_size = batch_size; - int data_size = 1000000; // FIXME - // For now we are assuming the length is always 1 - int index_size = rect_input.hi[1] - rect_input.lo[1] + 1; - coord_t in_dim = rect_input.hi[0] - rect_input.lo[0] + 1; - assert(in_dim == 1); - std::vector lengths(output_size, 1); - embed_forward(acc_input.ptr(rect_input), - lengths.data(), - acc_output.ptr(rect_output), - acc_weight.ptr(rect_weight), - block_size, - output_size, - index_size, - data_size); +template <> +void register_task() { + register_task(EMBED_FWD_TASK_ID, + "Embed Fwd", + fwd_signature(), + forward_task); } -void Embedding::backward_task_cpu(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - // const Embedding* embed = (Embedding*) task->args; - AccessorRO const acc_input(regions[0], FID_DATA); - AccessorRO const acc_output(regions[1], FID_DATA); - AccessorRW const acc_weight(regions[2], FID_DATA); - Rect<2> rect_input = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Rect<2> rect_output = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Rect<2> rect_weight = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); - coord_t batch_size = rect_input.hi[1] - rect_input.lo[1] + 1; - // Input and output have same batch size - assert(batch_size == rect_output.hi[1] - rect_output.lo[1] + 1); - // coord_t in_dim = rect_input.hi[0] - rect_input.lo[0] + 1; - coord_t out_dim = rect_output.hi[0] - rect_output.lo[0] + 1; - // Weight and output have same out dim - assert(out_dim == rect_weight.hi[1] - rect_weight.lo[1] + 1); - // const int64_t* input = acc_input.ptr(rect_input); - // const float* output = acc_output.ptr(rect_output); - // float* weight = acc_weight.ptr(rect_weight); - int block_size = out_dim; - int output_size = batch_size; - int index_size = rect_input.hi[1] - rect_input.lo[0] + 1; - int data_size = 1000000; // FIXME - std::vector lengths(output_size, 1); - embed_backward(acc_input.ptr(rect_input), - lengths.data(), - acc_output.ptr(rect_output), - acc_weight.ptr(rect_weight), - block_size, - output_size, - index_size, - data_size); +template <> +OpTaskSignature bwd_signature() { + OpTaskSignature bwd = infer_bwd_signature(fwd_signature()); + return bwd; } -EmbeddingMeta::EmbeddingMeta(FFHandler _handle, Op const *op) - : PerDeviceOpState(_handle, op) {} +template <> +void register_task() { + register_task(EMBED_BWD_TASK_ID, + "Embed Bwd", + bwd_signature(), + backward_task); } -; // namespace FlexFlow -namespace std { -size_t hash::operator()( - FlexFlow::EmbeddingParams const ¶ms) const { - size_t key = 0; - hash_combine(key, params.layer_guid.id); - hash_combine(key, params.out_channels); - hash_combine(key, params.aggr); - hash_combine(key, params.num_entries); - hash_combine(key, params.data_type); - return key; -} -}; // namespace std +} // namespace FlexFlow diff --git a/lib/runtime/src/ops/embedding.h b/lib/runtime/src/ops/embedding.h index 0496d93dd9..cd1b14fa66 100644 --- a/lib/runtime/src/ops/embedding.h +++ b/lib/runtime/src/ops/embedding.h @@ -2,124 +2,25 @@ #define _FLEXFLOW_EMBEDDING_H #include "op-attrs/ops/embedding.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { -template <> -void register_task(); template <> void register_task(); template <> void register_task(); -OpTaskInvocation init(EmbeddingAttrs const &); OpTaskInvocation forward(EmbeddingAttrs const &); OpTaskInvocation backward(EmbeddingAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, EmbeddingAttrs const &attrs, - ParallelTensorShape const &input_shape, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); -/* namespace Weight { */ -/* enum { */ -/* OUT_CHANNELS = 0, */ -/* VOCAB_SIZE = 1, */ -/* }; */ -/* }; */ - -/* namespace Output { */ -/* enum { OUT_CHANNELS = 0 }; */ -/* }; */ - -/* class Embedding; */ - -/* class Embedding : public Op { */ -/* public: */ -/* using Attrs = EmbeddingAttrs; */ - -/* Embedding(FFModel &model, */ -/* LayerID const &_layer_guid, */ -/* const ParallelTensor _input, */ -/* int _num_entries, */ -/* int _out_channels, */ -/* AggrMode _aggr, */ -/* bool allocate_weights, */ -/* DataType _dtype, */ -/* char const *name); */ -/* Embedding(FFModel &model, */ -/* Embedding const &other, */ -/* const ParallelTensor input, */ -/* bool allocate_weights); */ -/* Embedding(FFModel &model, */ -/* Attrs const ¶ms, */ -/* std::vector const &input, */ -/* bool allocate_weights = false, */ -/* char const *name = nullptr); */ -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* // void update(const FFModel&); */ -/* // Parameter* get_parameter(int index); */ -/* // void create_weights(FFModel& model); */ -/* // void create_input_partition(FFModel& model); */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void */ -/* forward_task_cpu(Legion::Task const *task, */ -/* std::vector const ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void */ -/* backward_task_cpu(Legion::Task const *task, */ -/* std::vector const ®ions, - */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ - -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* private: */ -/* int input_vocab_size_replica_dim() const; */ -/* int input_channel_out_replica_dim() const; */ -/* int output_vocab_size_replica_dim() const; */ - -/* int output_size(ParallelDim output_dims[MAX_TENSOR_DIM]); */ -/* int weight_size(ParallelDim weights_dims[MAX_TENSOR_DIM]); */ - -/* void register_mappings(); */ -/* void register_output_mappings(); */ -/* void register_weight_mappings(); */ - -/* public: */ -/* int num_entries, out_channels; */ -/* AggrMode aggr; */ -/* }; */ - } // namespace FlexFlow #endif