diff --git a/lib/kernels/include/kernels/flat_kernels.h b/lib/kernels/include/kernels/flat_kernels.h index 955bf115f4..90faec9427 100644 --- a/lib/kernels/include/kernels/flat_kernels.h +++ b/lib/kernels/include/kernels/flat_kernels.h @@ -1,26 +1,20 @@ #ifndef _FLEXFLOW_OPS_KERNELS_FLAT_KERNELS_H #define _FLEXFLOW_OPS_KERNELS_FLAT_KERNELS_H +#include "kernels/accessor.h" #include "kernels/device.h" namespace FlexFlow { - -class FlatPerDeviceState : public PerDeviceOpState { -public: - FlatPerDeviceState(FFHandler handle) : PerDeviceOpState(handle){}; -}; - namespace Kernels { namespace Flat { void forward_kernel(ffStream_t stream, - float const *input_ptr, - float *output_ptr, - size_t num_elements); + GenericTensorAccessorR input, + float *output_ptr); void backward_kernel(ffStream_t stream, + GenericTensorAccessorR input, float *input_grad_ptr, - float const *output_grad_ptr, - size_t num_elements); + float const *output_grad_ptr); } // namespace Flat } // namespace Kernels diff --git a/lib/kernels/src/cuda/flat_kernels.cu b/lib/kernels/src/cuda/flat_kernels.cu index ea00371aff..67cd28da87 100644 --- a/lib/kernels/src/cuda/flat_kernels.cu +++ b/lib/kernels/src/cuda/flat_kernels.cu @@ -13,40 +13,35 @@ * limitations under the License. */ -#include "kernels/cuda_helper.h" +#include "device.h" +#include "kernels/accessor.h" +#include "kernels/device.h" #include "kernels/flat_kernels.h" namespace FlexFlow { - namespace Kernels { namespace Flat { void forward_kernel(cudaStream_t stream, - float const *input_ptr, - float *output_ptr, - size_t num_elements) { + GenericTensorAccessorR input, + float *output_ptr) { checkCUDA(cudaMemcpyAsync(output_ptr, - input_ptr, - num_elements * sizeof(float), + input.get_float_ptr(), + (input.shape.num_elements()) * sizeof(float), cudaMemcpyDeviceToDevice, stream)); - // checkCUDA(cudaDeviceSynchronize()); } void backward_kernel(cudaStream_t stream, + GenericTensorAccessorR input, float *input_grad_ptr, - float const *output_grad_ptr, - size_t num_elements) { + float const *output_grad_ptr) { float alpha = 1.0f; apply_add_with_scale - <<>>( - input_grad_ptr, output_grad_ptr, num_elements, alpha); - // checkCUDA(cudaMemcpyAsync(acc_input_grad.ptr, acc_output_grad.ptr, - // acc_input_grad.rect.volume() * sizeof(float), - // cudaMemcpyDeviceToDevice)); - // checkCUDA(cudaDeviceSynchronize()); + <<>>( + input_grad_ptr, output_grad_ptr, input.shape.num_elements(), alpha); } } // namespace Flat diff --git a/lib/kernels/src/hip/flat_kernels.cpp b/lib/kernels/src/hip/flat_kernels.cpp index 47d6fa0ce8..93543e6177 100644 --- a/lib/kernels/src/hip/flat_kernels.cpp +++ b/lib/kernels/src/hip/flat_kernels.cpp @@ -23,32 +23,31 @@ namespace Kernels { namespace Flat { void forward_kernel(hipStream_t stream, - float const *input_ptr, - float *output_ptr, - size_t num_elements) { + GenericTensorAccessorR input, + float *output_ptr) { checkCUDA(hipMemcpyAsync(output_ptr, - input_ptr, - num_elements * sizeof(float), + input.get_float_ptr(), + (input.shape.num_elements()) * sizeof(float), hipMemcpyDeviceToDevice, stream)); // checkCUDA(hipDeviceSynchronize()); } void backward_kernel(hipStream_t stream, + GenericTensorAccessorR input, float *input_grad_ptr, - float const *output_grad_ptr, - size_t num_elements) { + float const *output_grad_ptr) { float alpha = 1.0f; hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale), - GET_BLOCKS(num_elements), + GET_BLOCKS(input.shape.num_elements()), CUDA_NUM_THREADS, 0, stream, input_grad_ptr, output_grad_ptr, - num_elements, + input.shape.num_elements(), alpha); // checkCUDA(hipMemcpyAsync(acc_input_grad.ptr, acc_output_grad.ptr, // acc_input_grad.rect.volume() * sizeof(float), diff --git a/lib/runtime/src/ops/flat.cc b/lib/runtime/src/ops/flat.cc index 53b4c4b770..f53a6185b6 100644 --- a/lib/runtime/src/ops/flat.cc +++ b/lib/runtime/src/ops/flat.cc @@ -1,350 +1,136 @@ #include "flat.h" +#include "kernels/flat_kernels.h" +#include "op-attrs/get_output_shapes.h" namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Flat; -Tensor FFModel::flat(const Tensor input, char const *name) { - assert(input->num_dims == 4); - Layer *flat = new Layer(this, - OP_FLAT, - DT_FLOAT, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input); - int dims[MAX_TENSOR_DIM]; - dims[1] = input->dims[3]; - dims[0] = input->dims[2] * input->dims[1] * input->dims[0]; - flat->outputs[0] = create_tensor_legion_ordering( - 2, dims, DT_FLOAT, flat, 0, true /*create_grad*/); - layers.push_back(flat); - return flat->outputs[0]; -} +enum SLOTS { INPUT, OUTPUT, HANDLE, PROFILING }; -Op *Flat::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - return new Flat(model, inputs[0], layer->name); -} +OpTaskInvocation forward(FlatAttrs const &attrs) { + OpTaskBinding binding; -int FlatParams::output_size(ParallelTensorShape const &input, - ParallelDim output_dims[MAX_TENSOR_DIM]) const { - output_dims[FlatOutput::REPLICA].is_replica_dim = true; - output_dims[FlatOutput::SAMPLE].size = input.dims[FlatInput::SAMPLE].size; - output_dims[FlatOutput::CHANNEL].size = - (input.dims[FlatInput::CHANNEL].size * - input.dims[FlatInput::HEIGHT].size * input.dims[FlatInput::WIDTH].size); + binding.bind(INPUT, input_tensor(0)); + binding.bind(OUTPUT, output_tensor(0)); - return FlatOutput::NUMDIM; + binding.bind_arg(PROFILING, profiling_settings()); + return {FLAT_FWD_TASK_ID, binding}; } -Flat::Flat(FFModel &model, const ParallelTensor _input, char const *name) - : Op(model, - OP_FLAT, - _input->data_type, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - _input) { - assert(_input->num_dims == FlatInput::NUMDIM); - - Flat::construct_output_mappings(*this->parallel_dims_mapping); +OpTaskInvocation backward(FlatAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - ParallelDim output_dims[MAX_TENSOR_DIM]; - int output_ndims; - this->get_params().solve_dims( - this->inputs[0]->get_shape(), output_dims, &output_ndims); - - outputs[0] = model.create_parallel_tensor_legion_ordering( - output_ndims, output_dims, _input->data_type, this); - - assert(check_output_input_weight_parallel_dims()); + return {FLAT_BWD_TASK_ID, b}; } -Flat::Flat(FFModel &model, - FlatParams const ¶ms, - const ParallelTensor input, - char const *name) - : Flat(model, input, name) {} +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); -void Flat::init(FFModel const &ff) { - assert(check_output_input_weight_same_parallel_is()); - parallel_is = outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_init(ff, argmap); - IndexLauncher launcher(FLAT_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(Flat)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(1, FID_DATA); - FutureMap fm = runtime->execute_index_space(ctx, launcher); - fm.wait_all_results(); - set_opmeta_from_futuremap(ff, fm); + return profile(forward_kernel, + profiling, + "[Flat] forward_time = %.2lfms\n", + input, + output.get_float_ptr()); } -PerDeviceOpState *Flat::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - FFHandler handler = *((FFHandler const *)task->local_args); - FlatMeta *m = new FlatMeta(handler); - return m; -} - -void Flat::forward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_forward(ff, argmap); - IndexLauncher launcher(FLAT_FWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(1, FID_DATA); - runtime->execute_index_space(ctx, launcher); -} - -/* - regions[0](I): input - regions[1](O): output -*/ -void Flat::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - TensorAccessorR acc_input( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - TensorAccessorW acc_output(regions[1], - task->regions[1], - FID_DATA, - ctx, - runtime, - false /*readOutput*/); - assert(acc_input.rect.volume() == acc_output.rect.volume()); - - forward_kernel_wrapper( - acc_input.ptr, acc_output.ptr, acc_input.rect.volume()); - // checkCUDA(cudaDeviceSynchronize()); -} - -void Flat::backward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_backward(ff, argmap); - IndexLauncher launcher(FLAT_BWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - inputs[0]->region_grad)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - runtime->execute_index_space(ctx, launcher); -} - -/* - regions[0](I/O) : input_grad - regions[1](I) : output_grad -*/ -void Flat::backward_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - TensorAccessorW acc_input_grad(regions[0], - task->regions[0], - FID_DATA, - ctx, - runtime, - true /*readOutput*/); - TensorAccessorR acc_output_grad( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - assert(acc_input_grad.rect.volume() == acc_output_grad.rect.volume()); - - backward_kernel_wrapper( - acc_input_grad.ptr, acc_output_grad.ptr, acc_input_grad.rect.volume()); + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -Domain Flat::get_input_tensor_shape(ParallelConfig const &pc, - int input_idx, - int part_idx) const { - assert(input_idx < numInputs); - assert(pc.nDims == 3); - assert(pc.dim[0] == 1); - assert(pc.dim[2] == 1); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input = acc.get_tensor(INPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); - Domain d; - d.dim = inputs[input_idx]->num_dims; - for (int i = 0; i < d.dim - 1; i++) { - d.rect_data[i] = 0; - d.rect_data[i + d.dim] = inputs[input_idx]->dims[i].size - 1; - } - assert(inputs[input_idx]->dims[d.dim - 2].size % pc.num_parts() == 0); - int dim_size = inputs[input_idx]->dims[d.dim - 2].size / pc.num_parts(); - d.rect_data[d.dim - 2] = part_idx * dim_size; - d.rect_data[2 * d.dim - 2] = d.rect_data[d.dim - 2] + dim_size - 1; - return d; + return profile(backward_kernel, + profiling, + "[Flat] forward_time = %.2lfms\n", + input, + input_grad.get_float_ptr(), + output_grad.get_float_ptr()); } -void Flat::serialize(Legion::Serializer &sez) const { - return; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -bool Flat::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_input, sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + FlatAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); - sim->free_all(); - float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - assert(input_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape.shape); - float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(output_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - size_t num_elements = sub_output.get_volume(); + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); - std::function forward, backward; - forward = [&] { - forward_kernel_wrapper(input_ptr, output_ptr, num_elements); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - float *input_grad_ptr = - (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - float *output_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); + auto fwd_accessor = env.get_fwd_accessor(FLAT_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(FLAT_BWD_TASK_ID, bwd_binding); - assert(output_grad_ptr != NULL); - assert(input_grad_ptr != NULL); - backward = [&] { - backward_kernel_wrapper(input_grad_ptr, output_grad_ptr, num_elements); - }; - } + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); - inner_measure_operator_cost(sim, forward, backward, cost_metrics); + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); +} + +template <> +OpTaskSignature fwd_signature() { + OpTaskSignature fwd(OpTaskType::FWD); - if (sim->computationMode == COMP_MODE_TRAINING) { - log_measure.debug( - "[Measure Flat] name(%s) forward_time(%.4lf) backward_time(%.4lf)\n", - name, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - log_measure.debug("[Measure Flat] name(%s) forward_time(%.4lf)\n", - name, - cost_metrics.forward_time); - } + fwd.add_arg_slot(PROFILING); + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); - return true; + return fwd; } -FlatParams Flat::get_params() const { - FlatParams params; - return params; +template <> +void register_task() { + register_task(FLAT_FWD_TASK_ID, + "Flat Fwd", + fwd_signature(), + forward_task); } -using PCG::Node; -/*static*/ -Node Flat::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 1); - return ff.get_or_create_node(inputs[0], {}); +template <> +OpTaskSignature bwd_signature() { + OpTaskSignature bwd = infer_bwd_signature(fwd_signature()); + + return bwd; } -Op *Flat::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - assert(num_inputs == 1); - return new Flat(ff, inputs[0], this->name); +template <> +void register_task() { + register_task(FLAT_BWD_TASK_ID, + "Flat Bwd", + bwd_signature(), + backward_task); } }; // namespace FlexFlow - -namespace std { -size_t hash::operator()( - FlexFlow::FlatParams const ¶ms) const { - size_t key = 0; - return hash{}(key); -} -}; // namespace std diff --git a/lib/runtime/src/ops/flat.h b/lib/runtime/src/ops/flat.h index 653e302e8e..13246028fb 100644 --- a/lib/runtime/src/ops/flat.h +++ b/lib/runtime/src/ops/flat.h @@ -2,19 +2,15 @@ #define _FLEXFLOW_FLAT_H #include "op-attrs/ops/flat.h" -#include "op_task_invocation.h" #include "sim_environment.h" namespace FlexFlow { -template <> -void register_task(); template <> void register_task(); template <> void register_task(); -OpTaskInvocation init(FlatAttrs const &); OpTaskInvocation forward(FlatAttrs const &); OpTaskInvocation backward(FlatAttrs const &); @@ -23,53 +19,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, ParallelTensorShape const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); - -/* namespace FlatInput { */ -/* constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, */ -/* REPLICA = 4; */ -/* } */ - -/* namespace FlatOutput { */ -/* constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; */ -/* } */ - -/* class Flat : public Op { */ -/* public: */ -/* Flat(FFModel &model, ParallelTensor const &input, char const *name); */ -/* Flat(FFModel &model, */ -/* FlatAttrs const ¶ms, */ -/* std::vector const &input, */ -/* char const *name = nullptr); */ - -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ -/* }; */ - } // namespace FlexFlow #endif