From 951222e7b4188adaafa8ac44b30e4998c4eda6e6 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Thu, 7 Sep 2023 13:36:40 -0700 Subject: [PATCH 1/9] cast --- lib/kernels/include/kernels/cast_kernels.h | 19 +- lib/runtime/src/ops/cast.cc | 503 +++++---------------- lib/runtime/src/ops/cast.h | 267 ++++++++++- 3 files changed, 396 insertions(+), 393 deletions(-) diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h index d43446883c..44f98de751 100644 --- a/lib/kernels/include/kernels/cast_kernels.h +++ b/lib/kernels/include/kernels/cast_kernels.h @@ -3,19 +3,26 @@ #include "kernels/accessor.h" #include "kernels/device.h" -#include "op-attrs/ffconst.h" namespace FlexFlow { -class CastPerDeviceState : public PerDeviceOpState { -public: - CastPerDeviceState(FFHandler handle); - DataType input_data_type, output_data_type; +struct CastPerDeviceState { + PerDeviceFFHandle handle; + DataType input_data_type; + req output_data_type; }; +FF_VISITABLE_STRUCT_NO_EQ(CastPerDeviceState, + handle, + input_data_type, + output_data_type); + namespace Kernels { namespace Cast { +CastPerDeviceState + init_kernel(PerDeviceFFHandle const &, DataType input, DataType output); + void forward_kernel(ffStream_t stream, CastPerDeviceState const *, GenericTensorAccessorR const &input, @@ -30,4 +37,4 @@ void backward_kernel(ffStream_t stream, } // namespace Kernels } // namespace FlexFlow -#endif +#endif \ No newline at end of file diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index 23c1bc9940..c0a933f4e8 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -16,441 +16,178 @@ #include "cast.h" #include "kernels/cast_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" #include "utils/hash-utils.h" using namespace FlexFlow::Kernels::Cast; -namespace FlexFlow { - -enum Slots { - INPUT, - OUTPUT, - INPUT_GRAD, - OUTPUT_GRAD, - ATTRS, - PROFILING -} - -// declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; - -Tensor FFModel::cast(const Tensor input, DataType dtype, char const *name) { - Layer *cast = new Layer(this, - OP_CAST, - dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input); - int numdims = input->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = input->dims[i]; - } - cast->outputs[0] = create_tensor_legion_ordering( - numdims, dims, dtype, cast, 0, true /*create_grad*/); - cast->add_int_property("dtype", dtype); - layers.push_back(cast); - return cast->outputs[0]; -} - -Op *Cast::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("dtype", value); - DataType dtype = (DataType)value; - return new Cast(model, inputs[0], dtype, layer->name); -} - -CastParams Cast::get_params() const { - CastParams params; - params.dtype = this->outputs[0]->data_type; - return params; -} - -Cast::Cast(FFModel &model, - ParallelTensor const &input, - DataType _dtype, - char const *name) - : Op(model, - OP_CAST, - _dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input) { - numOutputs = 1; - numWeights = 0; - int numdim = input->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = input->dims[i]; - } - outputs[0] = - model.create_parallel_tensor_legion_ordering(numdim, dims, _dtype, this); -} - -Cast::Cast(FFModel &model, - CastParams const ¶ms, - ParallelTensor const &input, - char const *name) - : Cast(model, input, params.dtype, name) {} - -static OpTaskSignature get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - - init.add_input_slot(INPUT); - init.add_output_slot(OUTPUT); - - return init; -} - -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - - return init; -} - -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); - - bwd.add_arg_slot(ATTRS); - bwd.add_input_grad_slot(INPUT_GRAD); - bwd.add_output_grad_slot(OUTPUT_GRAD); +namespace FlexFlow { - return bwd; -} +enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, PER_DEVICE_STATE, HANDLE }; -OpTaskBinding Cast::get_init_task_binding() const { +// declare Legion names +// using Legion::ArgumentMap; +// using Legion::Context; +// using Legion::coord_t; +// using Legion::Domain; +// using Legion::FutureMap; +// using Legion::IndexLauncher; +// using Legion::PhysicalRegion; +// using Legion::Predicate; +// using Legion::Rect; +// using Legion::RegionRequirement; +// using Legion::Runtime; +// using Legion::Task; +// using Legion::TaskArgument; +// using Legion::TaskLauncher; + +OpTaskInvocation init(CastAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(PROFILING, this->profiling); - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(HANDLE, ff_handle()); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {CAST_INIT_TASK_ID, binding}; } -OpTaskBinding Cast::get_fwd_task_binding() const { +OpTaskInvocation forward(CastAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(ATTRS, this->attrs); + binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind_arg(PROFILING, profiling_settings()); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return binding; + return {CAST_FWD_TASK_ID, binding}; } -OpTaskBinding Cast::get_bwd_task_binding() const { - OpTaskBinding binding; +OpTaskInvocation backward(CastAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - binding.bind_arg(ATTRS, this->attrs); + return {CAST_BWD_TASK_ID, binding}; +} - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { - return binding; -} + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); -void Cast::init(FFModel const &ff) { - this->execute_task(ff, CAST_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(CAST_INIT_TASK_ID, - // parallel_is, - // TaskArgument(this, sizeof(Cast)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, input.data_type, output.data_type)); + return per_device_state; } -PerDeviceOpState *Cast::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { TaskArgumentAccessor acc(task, regions, ctx, runtime); - - FFHandler handler = *((FFHandler const *)task->local_args); - CastPerDeviceState *m = new CastPerDeviceState(handler); - bool profiling = acc.get_argument(PROFILING); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - m->input_data_type = input->data_type; - m->output_data_type = output->data_type; - m->profiling = profiling; - return m; + return init_task_impl(acc); } -void Cast::forward(FFModel const &ff) { - this->execute_task(ff, CAST_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(CAST_FWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, false), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // runtime->execute_index_space(ctx, launcher); -} +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); -// template -// void Cast::forward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->output_data_type == DT_FLOAT) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_DOUBLE) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_INT32) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->output_data_type == DT_INT64) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::forward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT *input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerWO( -// regions[1], task->regions[1], FID_DATA, ctx, runtime); -// forward_kernel_wrapper( -// m, input_ptr, output_ptr, output_domain.get_volume()); -// } - -void Cast::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); - // if (m->input_data_type == DT_FLOAT) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_DOUBLE) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_INT32) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_INT64) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - profile(forward_kernel, - m->profiling, - "[Cast] forward_time = %.2lfms\n", - m, - input, - output) -} + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); -void Cast::backward(FFModel const &ff) { - this->execute_task(ff, CAST_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(CAST_BWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, false), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // inputs[0]->region_grad)); - // launcher.add_field(1, FID_DATA); - // runtime->execute_index_space(ctx, launcher); + return profile(forward_kernel, + profiling, + "[Cast] forward_time = %.2lfms\n", + &per_device_state, + input, + output); } -// template -// void Cast::backward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->input_data_type == DT_FLOAT) { -// Cast::backward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->input_data_type == DT_DOUBLE) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT32) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT64) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::backward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT *input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerRW( -// regions[1], task->regions[1], FID_DATA, ctx, runtime); -// backward_kernel_wrapper( -// input_ptr, output_ptr, output_domain.get_volume()); -// } - -void Cast::backward_task(Task const *task, +static void forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); - // if (m->output_data_type == DT_FLOAT) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_DOUBLE) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_INT32) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_INT64) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } - auto input_grad = acc.get_tensor(INPUT); - auto output_grad = acc.get_tensor(OUTPUT); - - profile(backward_kernel, - m->profiling, - "[Cast] forward_time = %.2lfms\n", - m, - input_grad, - output_grad) + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); +} + +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); + + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + + return profile(backward_kernel, + profiling, + "[Cast] forward_time = %.2lfms\n", + &per_device_state, + input_grad, + output_grad); +} + +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -bool Cast::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + CastAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + // Assume cast has no cost - cost_metrics.forward_time = 0.0f; - cost_metrics.backward_time = 0.0f; - cost_metrics.inputs_memory = 0; - cost_metrics.outputs_memory = 0; - cost_metrics.weights_memory = 0; - return true; + float forward_time = 0.0; + float backward_time = 0.0; + float sync_time = 0.0; + return make_metrics(forward_time, backward_time, sync_time, env); } -void Cast::serialize(Legion::Serializer &sez) const { - sez.serialize(this->outputs[0]->data_type); +template <> +void register_task() { + OpTaskSignature init(OpTaskType::INIT); + + init.add_unchecked_arg_slot(HANDLE); + + init.add_input_slot(INPUT); + init.add_output_slot(OUTPUT); + + register_task(CAST_INIT_TASK_ID, "Cast Init", init, init_task); } -using PCG::Node; +template <> +void register_task() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); -Node Cast::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 1); - DataType dtype; - dez.deserialize(dtype); - return ff.get_or_create_node(inputs[0], {dtype}); + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + + register_task(CAST_FWD_TASK_ID, "Cast Fwd", fwd, forward_task); } -Op *Cast::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - assert(num_inputs == 1); - return new Cast(ff, inputs[0], this->outputs[0]->data_type, this->name); +template <> +void register_task() { + OpTaskSignature bwd = infer_bwd_signature(get_op_signature(CAST_FWD_TASK_ID)); + + register_task(CAST_BWD_TASK_ID, "Cast Bwd", bwd, backward_task); } -}; // namespace FlexFlow +}; // namespace FlexFlow \ No newline at end of file diff --git a/lib/runtime/src/ops/cast.h b/lib/runtime/src/ops/cast.h index 7d346584d5..731baa62ef 100644 --- a/lib/runtime/src/ops/cast.h +++ b/lib/runtime/src/ops/cast.h @@ -16,8 +16,8 @@ #define _FLEXFLOW_CAST_H #include "op-attrs/ops/cast.h" -#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -33,11 +33,54 @@ OpTaskInvocation forward(CastAttrs const &); OpTaskInvocation backward(CastAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - BatchNormAttrs const &attrs, + CastAttrs const &attrs, ParallelTensorShape const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); +} // namespace FlexFlow + +#endif + +// template +// void Cast::backward_task_with_1_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// if (m->input_data_type == DT_FLOAT) { +// Cast::backward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->input_data_type == DT_DOUBLE) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->input_data_type == DT_INT32) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->input_data_type == DT_INT64) { +// Cast::backward_task_with_2_type(task, regions, ctx, +// runtime); +// } +// } + +// template +// void Cast::backward_task_with_2_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// assert(regions.size() == 2); +// assert(task->regions.size() == regions.size()); +// // Domain input_domain = runtime->get_index_space_domain( +// // ctx, task->regions[0].region.get_index_space()); +// Domain output_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// const IDT *input_ptr = helperGetTensorPointerRO( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// ODT *output_ptr = helperGetTensorPointerRW( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); +// backward_kernel_wrapper( +// input_ptr, output_ptr, output_domain.get_volume()); +// } + /* class Cast : public Op { */ /* public: */ /* Cast(FFModel &model, */ @@ -80,6 +123,222 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, /* CostMetrics &cost_metrics) const; */ /* }; */ -} // namespace FlexFlow +// void Cast::backward(FFModel const &ff) { +// this->execute_task(ff, CAST_BWD_TASK_ID, get_bwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_backward(ff, argmap); +// IndexLauncher launcher(CAST_BWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, false), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// outputs[0]->region_grad)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// inputs[0]->region_grad)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } -#endif +// template +// void Cast::forward_task_with_1_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// if (m->output_data_type == DT_FLOAT) { +// Cast::forward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->output_data_type == DT_DOUBLE) { +// Cast::forward_task_with_2_type(task, regions, ctx, runtime); +// } else if (m->output_data_type == DT_INT32) { +// Cast::forward_task_with_2_type(task, regions, ctx, +// runtime); +// } else if (m->output_data_type == DT_INT64) { +// Cast::forward_task_with_2_type(task, regions, ctx, +// runtime); +// } +// } + +// template +// void Cast::forward_task_with_2_type(Task const *task, +// std::vector const +// ®ions, Context ctx, Runtime *runtime) +// { +// assert(regions.size() == 2); +// assert(task->regions.size() == regions.size()); +// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); +// // Domain input_domain = runtime->get_index_space_domain( +// // ctx, task->regions[0].region.get_index_space()); +// Domain output_domain = runtime->get_index_space_domain( +// ctx, task->regions[1].region.get_index_space()); +// const IDT *input_ptr = helperGetTensorPointerRO( +// regions[0], task->regions[0], FID_DATA, ctx, runtime); +// ODT *output_ptr = helperGetTensorPointerWO( +// regions[1], task->regions[1], FID_DATA, ctx, runtime); +// forward_kernel_wrapper( +// m, input_ptr, output_ptr, output_domain.get_volume()); +// } + +// void Cast::forward(FFModel const &ff) { +// this->execute_task(ff, CAST_FWD_TASK_ID, get_fwd_task_signature()); +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_forward(ff, argmap); +// IndexLauncher launcher(CAST_FWD_TASK_ID, +// parallel_is, +// TaskArgument(NULL, false), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// runtime->execute_index_space(ctx, launcher); +// } + +// void Cast::init(FFModel const &ff) { +// this->execute_task(ff, CAST_INIT_TASK_ID, get_init_task_signature()); +// assert(check_output_input_weight_same_parallel_is()); +// parallel_is = outputs[0]->parallel_is; +// ArgumentMap argmap; +// Context ctx = ff.config.lg_ctx; +// Runtime *runtime = ff.config.lg_hlr; +// set_argumentmap_for_init(ff, argmap); +// IndexLauncher launcher(CAST_INIT_TASK_ID, +// parallel_is, +// TaskArgument(this, sizeof(Cast)), +// argmap, +// Predicate::TRUE_PRED, +// false /*must*/, +// 0 /*mapper_id*/, +// outputs[0]->machine_view.hash()); +// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, +// 0 /*projection id*/, +// WRITE_ONLY, +// EXCLUSIVE, +// outputs[0]->region)); +// launcher.add_field(0, FID_DATA); +// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, +// 0 /*projection id*/, +// READ_ONLY, +// EXCLUSIVE, +// inputs[0]->region)); +// launcher.add_field(1, FID_DATA); +// FutureMap fm = runtime->execute_index_space(ctx, launcher); +// fm.wait_all_results(); +// set_opmeta_from_futuremap(ff, fm); +// } + +// void Cast::serialize(Legion::Serializer &sez) const { +// sez.serialize(this->outputs[0]->data_type); +// } + +// using PCG::Node; + +// Node Cast::deserialize(FFModel &ff, +// Legion::Deserializer &dez, +// ParallelTensor inputs[], +// int num_inputs) { +// assert(num_inputs == 1); +// DataType dtype; +// dez.deserialize(dtype); +// return ff.get_or_create_node(inputs[0], {dtype}); +// } + +// Op *Cast::materialize(FFModel &ff, +// ParallelTensor inputs[], +// int num_inputs) const { +// assert(num_inputs == 1); +// return new Cast(ff, inputs[0], this->outputs[0]->data_type, this->name); +// } + +// Cast::Cast(FFModel &model, +// ParallelTensor const &input, +// DataType _dtype, +// char const *name) +// : Op(model, +// OP_CAST, +// _dtype, +// name, +// 1 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// input) { +// numOutputs = 1; +// numWeights = 0; +// int numdim = input->num_dims; +// ParallelDim dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdim; i++) { +// dims[i] = input->dims[i]; +// } +// outputs[0] = +// model.create_parallel_tensor_legion_ordering(numdim, dims, _dtype, +// this); +// } + +// Tensor FFModel::cast(const Tensor input, DataType dtype, char const *name) { +// Layer *cast = new Layer(this, +// OP_CAST, +// dtype, +// name, +// 1 /*inputs*/, +// 0 /*weights*/, +// 1 /*outputs*/, +// input); +// int numdims = input->num_dims; +// int dims[MAX_TENSOR_DIM]; +// for (int i = 0; i < numdims; i++) { +// dims[i] = input->dims[i]; +// } +// cast->outputs[0] = create_tensor_legion_ordering( +// numdims, dims, dtype, cast, 0, true /*create_grad*/); +// cast->add_int_property("dtype", dtype); +// layers.push_back(cast); +// return cast->outputs[0]; +// } + +// Op *Cast::create_operator_from_layer( +// FFModel &model, +// Layer const *layer, +// std::vector const &inputs) { +// long long value; +// layer->get_int_property("dtype", value); +// DataType dtype = (DataType)value; +// return new Cast(model, inputs[0], dtype, layer->name); +// } + +// CastParams Cast::get_params() const { +// CastParams params; +// params.dtype = this->outputs[0]->data_type; +// return params; +// } + +// Cast::Cast(FFModel &model, +// CastParams const ¶ms, +// ParallelTensor const &input, +// char const *name) +// : Cast(model, input, params.dtype, name) {} \ No newline at end of file From 9dbc79f964f217676d98f5042108dbea6e2c0062 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 12 Sep 2023 12:58:27 -0700 Subject: [PATCH 2/9] cast cuda --- lib/kernels/include/kernels/cast_kernels.h | 26 +- lib/runtime/src/ops/cast.h | 309 +-------------------- 2 files changed, 17 insertions(+), 318 deletions(-) diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h index 44f98de751..adb209ed76 100644 --- a/lib/kernels/include/kernels/cast_kernels.h +++ b/lib/kernels/include/kernels/cast_kernels.h @@ -3,35 +3,35 @@ #include "kernels/accessor.h" #include "kernels/device.h" +#include "kernels/ff_handle.h" +#include "op-attrs/activation.h" namespace FlexFlow { struct CastPerDeviceState { PerDeviceFFHandle handle; - DataType input_data_type; - req output_data_type; -}; + }; -FF_VISITABLE_STRUCT_NO_EQ(CastPerDeviceState, - handle, - input_data_type, - output_data_type); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(CastPerDeviceState, handle); namespace Kernels { namespace Cast { -CastPerDeviceState - init_kernel(PerDeviceFFHandle const &, DataType input, DataType output); +CastPerDeviceState init_kernel(PerDeviceFFHandle const &handle); void forward_kernel(ffStream_t stream, - CastPerDeviceState const *, + CastPerDeviceState const *m, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output); + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type); void backward_kernel(ffStream_t stream, - CastPerDeviceState const *, + CastPerDeviceState const *m, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output); + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type); } // namespace Cast } // namespace Kernels diff --git a/lib/runtime/src/ops/cast.h b/lib/runtime/src/ops/cast.h index 731baa62ef..c9aa12bc61 100644 --- a/lib/runtime/src/ops/cast.h +++ b/lib/runtime/src/ops/cast.h @@ -32,313 +32,12 @@ OpTaskInvocation init(CastAttrs const &); OpTaskInvocation forward(CastAttrs const &); OpTaskInvocation backward(CastAttrs const &); -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, +CostMetrics measure_operator_cost(SimEnvFactory const &sim, CastAttrs const &attrs, - ParallelTensorShape const &input_shape, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, - MachineView const &machine_view); + MachineView const &mv); } // namespace FlexFlow -#endif - -// template -// void Cast::backward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->input_data_type == DT_FLOAT) { -// Cast::backward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->input_data_type == DT_DOUBLE) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT32) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT64) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::backward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT *input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerRW( -// regions[1], task->regions[1], FID_DATA, ctx, runtime); -// backward_kernel_wrapper( -// input_ptr, output_ptr, output_domain.get_volume()); -// } - -/* class Cast : public Op { */ -/* public: */ -/* Cast(FFModel &model, */ -/* ParallelTensor const &input, */ -/* DataType dtype, */ -/* char const *name); */ -/* Cast(FFModel &model, */ -/* CastAttrs const ¶ms, */ -/* std::vector const &input, */ -/* char const *name = nullptr); */ -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* OpTaskBinding get_init_task_binding() const override; */ -/* OpTaskBinding get_fwd_task_binding() const override; */ -/* OpTaskBinding get_bwd_task_binding() const override; */ - -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const; */ -/* }; */ - -// void Cast::backward(FFModel const &ff) { -// this->execute_task(ff, CAST_BWD_TASK_ID, get_bwd_task_signature()); -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_backward(ff, argmap); -// IndexLauncher launcher(CAST_BWD_TASK_ID, -// parallel_is, -// TaskArgument(NULL, false), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region_grad)); -// launcher.add_field(0, FID_DATA); -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// inputs[0]->region_grad)); -// launcher.add_field(1, FID_DATA); -// runtime->execute_index_space(ctx, launcher); -// } - -// template -// void Cast::forward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->output_data_type == DT_FLOAT) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_DOUBLE) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_INT32) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->output_data_type == DT_INT64) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::forward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT *input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerWO( -// regions[1], task->regions[1], FID_DATA, ctx, runtime); -// forward_kernel_wrapper( -// m, input_ptr, output_ptr, output_domain.get_volume()); -// } - -// void Cast::forward(FFModel const &ff) { -// this->execute_task(ff, CAST_FWD_TASK_ID, get_fwd_task_signature()); -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_forward(ff, argmap); -// IndexLauncher launcher(CAST_FWD_TASK_ID, -// parallel_is, -// TaskArgument(NULL, false), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(1, FID_DATA); -// runtime->execute_index_space(ctx, launcher); -// } - -// void Cast::init(FFModel const &ff) { -// this->execute_task(ff, CAST_INIT_TASK_ID, get_init_task_signature()); -// assert(check_output_input_weight_same_parallel_is()); -// parallel_is = outputs[0]->parallel_is; -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_init(ff, argmap); -// IndexLauncher launcher(CAST_INIT_TASK_ID, -// parallel_is, -// TaskArgument(this, sizeof(Cast)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[0]->region)); -// launcher.add_field(1, FID_DATA); -// FutureMap fm = runtime->execute_index_space(ctx, launcher); -// fm.wait_all_results(); -// set_opmeta_from_futuremap(ff, fm); -// } - -// void Cast::serialize(Legion::Serializer &sez) const { -// sez.serialize(this->outputs[0]->data_type); -// } - -// using PCG::Node; - -// Node Cast::deserialize(FFModel &ff, -// Legion::Deserializer &dez, -// ParallelTensor inputs[], -// int num_inputs) { -// assert(num_inputs == 1); -// DataType dtype; -// dez.deserialize(dtype); -// return ff.get_or_create_node(inputs[0], {dtype}); -// } - -// Op *Cast::materialize(FFModel &ff, -// ParallelTensor inputs[], -// int num_inputs) const { -// assert(num_inputs == 1); -// return new Cast(ff, inputs[0], this->outputs[0]->data_type, this->name); -// } - -// Cast::Cast(FFModel &model, -// ParallelTensor const &input, -// DataType _dtype, -// char const *name) -// : Op(model, -// OP_CAST, -// _dtype, -// name, -// 1 /*inputs*/, -// 0 /*weights*/, -// 1 /*outputs*/, -// input) { -// numOutputs = 1; -// numWeights = 0; -// int numdim = input->num_dims; -// ParallelDim dims[MAX_TENSOR_DIM]; -// for (int i = 0; i < numdim; i++) { -// dims[i] = input->dims[i]; -// } -// outputs[0] = -// model.create_parallel_tensor_legion_ordering(numdim, dims, _dtype, -// this); -// } - -// Tensor FFModel::cast(const Tensor input, DataType dtype, char const *name) { -// Layer *cast = new Layer(this, -// OP_CAST, -// dtype, -// name, -// 1 /*inputs*/, -// 0 /*weights*/, -// 1 /*outputs*/, -// input); -// int numdims = input->num_dims; -// int dims[MAX_TENSOR_DIM]; -// for (int i = 0; i < numdims; i++) { -// dims[i] = input->dims[i]; -// } -// cast->outputs[0] = create_tensor_legion_ordering( -// numdims, dims, dtype, cast, 0, true /*create_grad*/); -// cast->add_int_property("dtype", dtype); -// layers.push_back(cast); -// return cast->outputs[0]; -// } - -// Op *Cast::create_operator_from_layer( -// FFModel &model, -// Layer const *layer, -// std::vector const &inputs) { -// long long value; -// layer->get_int_property("dtype", value); -// DataType dtype = (DataType)value; -// return new Cast(model, inputs[0], dtype, layer->name); -// } - -// CastParams Cast::get_params() const { -// CastParams params; -// params.dtype = this->outputs[0]->data_type; -// return params; -// } - -// Cast::Cast(FFModel &model, -// CastParams const ¶ms, -// ParallelTensor const &input, -// char const *name) -// : Cast(model, input, params.dtype, name) {} \ No newline at end of file +#endif \ No newline at end of file From 922e40b175533cf83cb8ededc309174f108ae297 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 12 Sep 2023 13:03:41 -0700 Subject: [PATCH 3/9] cast cc --- lib/runtime/src/ops/cast.cc | 47 +++++++++++++------------------------ 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index c0a933f4e8..3ac11b2294 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -29,30 +29,11 @@ namespace FlexFlow { enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, PER_DEVICE_STATE, HANDLE }; -// declare Legion names -// using Legion::ArgumentMap; -// using Legion::Context; -// using Legion::coord_t; -// using Legion::Domain; -// using Legion::FutureMap; -// using Legion::IndexLauncher; -// using Legion::PhysicalRegion; -// using Legion::Predicate; -// using Legion::Rect; -// using Legion::RegionRequirement; -// using Legion::Runtime; -// using Legion::Task; -// using Legion::TaskArgument; -// using Legion::TaskLauncher; - OpTaskInvocation init(CastAttrs const &attrs) { OpTaskBinding binding; binding.bind_arg(HANDLE, ff_handle()); - binding.bind(INPUT, input_tensor(0)); - binding.bind(OUTPUT, output_tensor(0)); - return {CAST_INIT_TASK_ID, binding}; } @@ -61,6 +42,7 @@ OpTaskInvocation forward(CastAttrs const &attrs) { binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); binding.bind_arg(PROFILING, profiling_settings()); + b.bind_arg(ATTRS, attrs); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); @@ -78,12 +60,10 @@ static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { PerDeviceFFHandle handle = acc.get_argument(HANDLE); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); DeviceSpecific per_device_state = acc.create_device_specific( - init_kernel(handle, input.data_type, output.data_type)); + init_kernel(handle)); return per_device_state; } @@ -97,9 +77,9 @@ static DeviceSpecific } static optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); + auto const &attrs = acc.get_argument(ATTRS); auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); @@ -109,7 +89,9 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { "[Cast] forward_time = %.2lfms\n", &per_device_state, input, - output); + output, + input.data_type, + attrs.dtype); } static void forward_task(Task const *task, @@ -121,9 +103,11 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); + auto const &attrs = acc.get_argument(ATTRS); + + auto input = acc.get_tensor(INPUT); auto input_grad = acc.get_tensor_grad(INPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); @@ -133,7 +117,9 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { "[Cast] forward_time = %.2lfms\n", &per_device_state, input_grad, - output_grad); + output_grad, + input.data_type, + attrs.dtype); } static void backward_task(Task const *task, @@ -163,9 +149,7 @@ void register_task() { OpTaskSignature init(OpTaskType::INIT); init.add_unchecked_arg_slot(HANDLE); - - init.add_input_slot(INPUT); - init.add_output_slot(OUTPUT); + init.add_return_value(); register_task(CAST_INIT_TASK_ID, "Cast Init", init, init_task); } @@ -174,6 +158,7 @@ template <> void register_task() { OpTaskSignature fwd(OpTaskType::FWD); + fwd.add_arg_slot(ATTRS); fwd.add_arg_slot(PROFILING); fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); From a1bb222338ad4b5bcdc58c3af6d738ec970ce335 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 12 Sep 2023 15:19:37 -0700 Subject: [PATCH 4/9] format --- lib/kernels/include/kernels/cast_kernels.h | 4 ++-- lib/kernels/src/cuda/cast_kernels.cu | 23 +++++++++++++--------- lib/runtime/src/ops/cast.cc | 13 ++++++------ lib/runtime/src/ops/cast.h | 2 +- 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h index adb209ed76..db6d431c09 100644 --- a/lib/kernels/include/kernels/cast_kernels.h +++ b/lib/kernels/include/kernels/cast_kernels.h @@ -10,7 +10,7 @@ namespace FlexFlow { struct CastPerDeviceState { PerDeviceFFHandle handle; - }; +}; FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(CastPerDeviceState, handle); @@ -37,4 +37,4 @@ void backward_kernel(ffStream_t stream, } // namespace Kernels } // namespace FlexFlow -#endif \ No newline at end of file +#endif diff --git a/lib/kernels/src/cuda/cast_kernels.cu b/lib/kernels/src/cuda/cast_kernels.cu index 1afcdb56cc..86f4ca7d93 100644 --- a/lib/kernels/src/cuda/cast_kernels.cu +++ b/lib/kernels/src/cuda/cast_kernels.cu @@ -14,14 +14,11 @@ */ #include "kernels/cast_kernels.h" -#include "kernels/cuda_helper.h" #include "kernels/datatype_dispatch.h" +#include "device.h" +#include "kernels/device.h" namespace FlexFlow { - -CastPerDeviceState::CastPerDeviceState(FFHandler handle) - : PerDeviceOpState(handle) {} - namespace Kernels { namespace Cast { @@ -64,20 +61,28 @@ struct BackwardKernel { } }; +CastPerDeviceState init_kernel(PerDeviceFFHandle const &handle) { + return {handle}; +} + void forward_kernel(ffStream_t stream, CastPerDeviceState const *m, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output) { + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type) { DataTypeDispatch2{}( - m->input_data_type, m->output_data_type, stream, m, input, output); + input_type, output_type, stream, m, input, output); } void backward_kernel(ffStream_t stream, CastPerDeviceState const *m, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output) { + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type) { DataTypeDispatch2{}( - m->input_data_type, m->output_data_type, stream, m, input, output); + input_type, output_type, stream, m, input, output); } } // namespace Cast diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index 3ac11b2294..7561b8bf4b 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -62,8 +62,7 @@ static DeviceSpecific PerDeviceFFHandle handle = acc.get_argument(HANDLE); DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(handle)); + acc.create_device_specific(init_kernel(handle)); return per_device_state; } @@ -77,7 +76,8 @@ static DeviceSpecific } static optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); auto const &attrs = acc.get_argument(ATTRS); @@ -103,10 +103,11 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); auto const &attrs = acc.get_argument(ATTRS); - + auto input = acc.get_tensor(INPUT); auto input_grad = acc.get_tensor_grad(INPUT); @@ -175,4 +176,4 @@ void register_task() { register_task(CAST_BWD_TASK_ID, "Cast Bwd", bwd, backward_task); } -}; // namespace FlexFlow \ No newline at end of file +}; // namespace FlexFlow diff --git a/lib/runtime/src/ops/cast.h b/lib/runtime/src/ops/cast.h index c9aa12bc61..c0c500e869 100644 --- a/lib/runtime/src/ops/cast.h +++ b/lib/runtime/src/ops/cast.h @@ -40,4 +40,4 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, } // namespace FlexFlow -#endif \ No newline at end of file +#endif From b9307b3d90bb646180750c62149ed36f070588c2 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 3 Oct 2023 18:01:17 -0700 Subject: [PATCH 5/9] cast --- lib/kernels/include/kernels/cast_kernels.h | 17 ++--- lib/kernels/src/cuda/cast_kernels.cu | 22 +++--- lib/runtime/src/ops/cast.cc | 84 ++++++++-------------- 3 files changed, 44 insertions(+), 79 deletions(-) diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h index db6d431c09..cb1b566c78 100644 --- a/lib/kernels/include/kernels/cast_kernels.h +++ b/lib/kernels/include/kernels/cast_kernels.h @@ -7,31 +7,22 @@ #include "op-attrs/activation.h" namespace FlexFlow { - -struct CastPerDeviceState { - PerDeviceFFHandle handle; -}; - -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(CastPerDeviceState, handle); - namespace Kernels { namespace Cast { -CastPerDeviceState init_kernel(PerDeviceFFHandle const &handle); - void forward_kernel(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, - DataType output_type); + DataType output_type, + PerDeviceFFHandle handle); void backward_kernel(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, - DataType output_type); + DataType output_type, + PerDeviceFFHandle handle); } // namespace Cast } // namespace Kernels diff --git a/lib/kernels/src/cuda/cast_kernels.cu b/lib/kernels/src/cuda/cast_kernels.cu index 86f4ca7d93..b3272fc963 100644 --- a/lib/kernels/src/cuda/cast_kernels.cu +++ b/lib/kernels/src/cuda/cast_kernels.cu @@ -13,9 +13,9 @@ * limitations under the License. */ +#include "device.h" #include "kernels/cast_kernels.h" #include "kernels/datatype_dispatch.h" -#include "device.h" #include "kernels/device.h" namespace FlexFlow { @@ -40,7 +40,7 @@ __global__ void template struct ForwardKernel { void operator()(ffStream_t stream, - CastPerDeviceState const *m, + PerDeviceFFHandle handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -52,7 +52,7 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(ffStream_t stream, - CastPerDeviceState const *m, + PerDeviceFFHandle handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -61,28 +61,24 @@ struct BackwardKernel { } }; -CastPerDeviceState init_kernel(PerDeviceFFHandle const &handle) { - return {handle}; -} - void forward_kernel(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, - DataType output_type) { + DataType output_type, + PerDeviceFFHandle handle) { DataTypeDispatch2{}( - input_type, output_type, stream, m, input, output); + input_type, output_type, stream, handle, input, output); } void backward_kernel(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, - DataType output_type) { + DataType output_type, + PerDeviceFFHandle handle) { DataTypeDispatch2{}( - input_type, output_type, stream, m, input, output); + input_type, output_type, stream, handle, input, output); } } // namespace Cast diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index 7561b8bf4b..00cd3af6ff 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -16,6 +16,7 @@ #include "cast.h" #include "kernels/cast_kernels.h" #include "legion/legion_utilities.h" +#include "task_spec/op_task_signature.h" #include "utils/hash-utils.h" using namespace FlexFlow::Kernels::Cast; @@ -27,22 +28,14 @@ using Legion::Task; namespace FlexFlow { -enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, PER_DEVICE_STATE, HANDLE }; - -OpTaskInvocation init(CastAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(HANDLE, ff_handle()); - - return {CAST_INIT_TASK_ID, binding}; -} +enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, HANDLE }; OpTaskInvocation forward(CastAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); + binding.bind_arg(HANDLE, ff_handle()); binding.bind_arg(PROFILING, profiling_settings()); - b.bind_arg(ATTRS, attrs); + binding.bind_arg(ATTRS, attrs); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); @@ -56,30 +49,10 @@ OpTaskInvocation backward(CastAttrs const &attrs) { return {CAST_BWD_TASK_ID, binding}; } -static DeviceSpecific - init_task_impl(TaskArgumentAccessor const &acc) { - - PerDeviceFFHandle handle = acc.get_argument(HANDLE); - - DeviceSpecific per_device_state = - acc.create_device_specific(init_kernel(handle)); - return per_device_state; -} - -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - static optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); auto const &attrs = acc.get_argument(ATTRS); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); @@ -87,11 +60,11 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, "[Cast] forward_time = %.2lfms\n", - &per_device_state, input, output, input.data_type, - attrs.dtype); + attrs.dtype, + handle); } static void forward_task(Task const *task, @@ -103,10 +76,9 @@ static void forward_task(Task const *task, } static optional backward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); auto const &attrs = acc.get_argument(ATTRS); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); auto input = acc.get_tensor(INPUT); @@ -116,11 +88,11 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, "[Cast] forward_time = %.2lfms\n", - &per_device_state, input_grad, output_grad, input.data_type, - attrs.dtype); + attrs.dtype, + handle); } static void backward_task(Task const *task, @@ -146,34 +118,40 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, } template <> -void register_task() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_unchecked_arg_slot(HANDLE); - init.add_return_value(); - - register_task(CAST_INIT_TASK_ID, "Cast Init", init, init_task); -} - -template <> -void register_task() { +OpTaskSignature fwd_signature() { OpTaskSignature fwd(OpTaskType::FWD); fwd.add_arg_slot(ATTRS); fwd.add_arg_slot(PROFILING); - fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + fwd.add_unchecked_arg_slot(HANDLE); fwd.add_input_slot(INPUT); fwd.add_output_slot(OUTPUT); - register_task(CAST_FWD_TASK_ID, "Cast Fwd", fwd, forward_task); + return fwd; } template <> -void register_task() { +void register_task() { + register_task(CAST_FWD_TASK_ID, + "Cast Fwd", + fwd_signature(), + forward_task); +} + +template <> +OpTaskSignature fwd_signature() { OpTaskSignature bwd = infer_bwd_signature(get_op_signature(CAST_FWD_TASK_ID)); - register_task(CAST_BWD_TASK_ID, "Cast Bwd", bwd, backward_task); + return bwd; +} + +template <> +void register_task() { + register_task(CAST_BWD_TASK_ID, + "Cast Bwd", + bwd_signature(), + backward_task); } }; // namespace FlexFlow From cbc89b173337c43692a0026ce8628d4349f52e03 Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 3 Oct 2023 18:09:42 -0700 Subject: [PATCH 6/9] cast --- lib/runtime/src/ops/cast.cc | 14 -------------- lib/runtime/src/ops/cast.h | 6 ------ 2 files changed, 20 deletions(-) diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index 00cd3af6ff..cf438d4473 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -103,20 +103,6 @@ static void backward_task(Task const *task, backward_task_impl(acc); } -CostMetrics measure_operator_cost(SimEnvFactory const &sim, - CastAttrs const &attrs, - InputParallelTensorDesc const &input_shape, - ProfilingSettings const &settings, - MachineView const &mv) { - auto env = sim.new_environment(); - - // Assume cast has no cost - float forward_time = 0.0; - float backward_time = 0.0; - float sync_time = 0.0; - return make_metrics(forward_time, backward_time, sync_time, env); -} - template <> OpTaskSignature fwd_signature() { OpTaskSignature fwd(OpTaskType::FWD); diff --git a/lib/runtime/src/ops/cast.h b/lib/runtime/src/ops/cast.h index c0c500e869..7c42b02f41 100644 --- a/lib/runtime/src/ops/cast.h +++ b/lib/runtime/src/ops/cast.h @@ -32,12 +32,6 @@ OpTaskInvocation init(CastAttrs const &); OpTaskInvocation forward(CastAttrs const &); OpTaskInvocation backward(CastAttrs const &); -CostMetrics measure_operator_cost(SimEnvFactory const &sim, - CastAttrs const &attrs, - InputParallelTensorDesc const &input_shape, - ProfilingSettings const &settings, - MachineView const &mv); - } // namespace FlexFlow #endif From b88dee6a50e2f232231c1bb11bcb3d9f0c70e1ee Mon Sep 17 00:00:00 2001 From: Kate Unger Date: Tue, 3 Oct 2023 18:14:08 -0700 Subject: [PATCH 7/9] add back measure_operator --- lib/runtime/src/ops/cast.cc | 14 ++++++++++++++ lib/runtime/src/ops/cast.h | 6 ++++++ 2 files changed, 20 insertions(+) diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index cf438d4473..00cd3af6ff 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -103,6 +103,20 @@ static void backward_task(Task const *task, backward_task_impl(acc); } +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + CastAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + + // Assume cast has no cost + float forward_time = 0.0; + float backward_time = 0.0; + float sync_time = 0.0; + return make_metrics(forward_time, backward_time, sync_time, env); +} + template <> OpTaskSignature fwd_signature() { OpTaskSignature fwd(OpTaskType::FWD); diff --git a/lib/runtime/src/ops/cast.h b/lib/runtime/src/ops/cast.h index 7c42b02f41..c0c500e869 100644 --- a/lib/runtime/src/ops/cast.h +++ b/lib/runtime/src/ops/cast.h @@ -32,6 +32,12 @@ OpTaskInvocation init(CastAttrs const &); OpTaskInvocation forward(CastAttrs const &); OpTaskInvocation backward(CastAttrs const &); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + CastAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv); + } // namespace FlexFlow #endif From 86a52050cb1a9c8bed14fb422633ffaa2debbf52 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Fri, 6 Oct 2023 23:32:25 -0700 Subject: [PATCH 8/9] Fix hip and measure op cost --- lib/kernels/src/hip/cast_kernels.cpp | 20 ++++++++++---------- lib/runtime/src/ops/cast.cc | 26 ++++++++++++++++++++------ 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/lib/kernels/src/hip/cast_kernels.cpp b/lib/kernels/src/hip/cast_kernels.cpp index 732114335d..08037051c3 100644 --- a/lib/kernels/src/hip/cast_kernels.cpp +++ b/lib/kernels/src/hip/cast_kernels.cpp @@ -19,10 +19,6 @@ #include namespace FlexFlow { - -CastPerDeviceState::CastPerDeviceState(FFHandler handle) - : PerDeviceOpState(handle) {} - namespace Kernels { namespace Cast { @@ -44,7 +40,7 @@ __global__ void template struct ForwardKernel { void operator()(ffStream_t stream, - CastPerDeviceState const *m, + PerDeviceFFHandle handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -62,7 +58,7 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(ffStream_t stream, - CastPerDeviceState const *m, + PerDeviceFFHandle handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -79,17 +75,21 @@ struct BackwardKernel { }; void forward_kernel(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output) { + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type, + PerDeviceFFHandle handle) { DataTypeDispatch2{}( m->input_data_type, m->output_data_type, stream, m, input, output); } void backward_kernel(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output) { + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type, + PerDeviceFFHandle handle) { DataTypeDispatch2{}( m->input_data_type, m->output_data_type, stream, m, input, output); } diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index 00cd3af6ff..57b0167b4e 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -110,10 +110,24 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, MachineView const &mv) { auto env = sim.new_environment(); - // Assume cast has no cost - float forward_time = 0.0; - float backward_time = 0.0; - float sync_time = 0.0; + SimTaskBinding fwd_binding; + fwd_binding.bind_arg(HANDLE, ff_handle()); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(ATTRS, attrs); + + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(OUTPUT, input_shape); // cast does not change shape + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = env.get_fwd_accessor(CAST_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(CAST_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } @@ -140,8 +154,8 @@ void register_task() { } template <> -OpTaskSignature fwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_op_signature(CAST_FWD_TASK_ID)); +OpTaskSignature bwd_signature() { + OpTaskSignature bwd = infer_bwd_signature(fwd_signature()); return bwd; } From 47398912969edb56a14d22a906cc217970a494f3 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Mon, 9 Oct 2023 16:22:20 -0700 Subject: [PATCH 9/9] Remove handle --- lib/kernels/include/kernels/cast_kernels.h | 6 ++---- lib/kernels/src/cuda/cast_kernels.cu | 8 ++------ lib/kernels/src/hip/cast_kernels.cpp | 12 ++++-------- lib/runtime/src/ops/cast.cc | 13 +++---------- 4 files changed, 11 insertions(+), 28 deletions(-) diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h index cb1b566c78..cc10342e75 100644 --- a/lib/kernels/include/kernels/cast_kernels.h +++ b/lib/kernels/include/kernels/cast_kernels.h @@ -14,15 +14,13 @@ void forward_kernel(ffStream_t stream, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, - DataType output_type, - PerDeviceFFHandle handle); + DataType output_type); void backward_kernel(ffStream_t stream, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, - DataType output_type, - PerDeviceFFHandle handle); + DataType output_type); } // namespace Cast } // namespace Kernels diff --git a/lib/kernels/src/cuda/cast_kernels.cu b/lib/kernels/src/cuda/cast_kernels.cu index b3272fc963..3d8804862d 100644 --- a/lib/kernels/src/cuda/cast_kernels.cu +++ b/lib/kernels/src/cuda/cast_kernels.cu @@ -40,7 +40,6 @@ __global__ void template struct ForwardKernel { void operator()(ffStream_t stream, - PerDeviceFFHandle handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -52,7 +51,6 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(ffStream_t stream, - PerDeviceFFHandle handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -65,8 +63,7 @@ void forward_kernel(ffStream_t stream, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, - DataType output_type, - PerDeviceFFHandle handle) { + DataType output_type) { DataTypeDispatch2{}( input_type, output_type, stream, handle, input, output); } @@ -75,8 +72,7 @@ void backward_kernel(ffStream_t stream, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, - DataType output_type, - PerDeviceFFHandle handle) { + DataType output_type) { DataTypeDispatch2{}( input_type, output_type, stream, handle, input, output); } diff --git a/lib/kernels/src/hip/cast_kernels.cpp b/lib/kernels/src/hip/cast_kernels.cpp index 08037051c3..cf0ea83275 100644 --- a/lib/kernels/src/hip/cast_kernels.cpp +++ b/lib/kernels/src/hip/cast_kernels.cpp @@ -40,7 +40,6 @@ __global__ void template struct ForwardKernel { void operator()(ffStream_t stream, - PerDeviceFFHandle handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -58,7 +57,6 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(ffStream_t stream, - PerDeviceFFHandle handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -78,20 +76,18 @@ void forward_kernel(ffStream_t stream, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, - DataType output_type, - PerDeviceFFHandle handle) { + DataType output_type) { DataTypeDispatch2{}( - m->input_data_type, m->output_data_type, stream, m, input, output); + input_type, output_type, stream, input, output); } void backward_kernel(ffStream_t stream, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, DataType input_type, - DataType output_type, - PerDeviceFFHandle handle) { + DataType output_type) { DataTypeDispatch2{}( - m->input_data_type, m->output_data_type, stream, m, input, output); + input_type, output_type, stream, input, output); } } // namespace Cast diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index 57b0167b4e..44230eaf46 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -28,12 +28,11 @@ using Legion::Task; namespace FlexFlow { -enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, HANDLE }; +enum Slots { INPUT, OUTPUT, ATTRS, PROFILING }; OpTaskInvocation forward(CastAttrs const &attrs) { OpTaskBinding binding; - binding.bind_arg(HANDLE, ff_handle()); binding.bind_arg(PROFILING, profiling_settings()); binding.bind_arg(ATTRS, attrs); @@ -52,7 +51,6 @@ OpTaskInvocation backward(CastAttrs const &attrs) { static optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto const &attrs = acc.get_argument(ATTRS); - PerDeviceFFHandle handle = acc.get_argument(HANDLE); auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); @@ -63,8 +61,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { input, output, input.data_type, - attrs.dtype, - handle); + attrs.dtype); } static void forward_task(Task const *task, @@ -78,7 +75,6 @@ static void forward_task(Task const *task, static optional backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto const &attrs = acc.get_argument(ATTRS); - PerDeviceFFHandle handle = acc.get_argument(HANDLE); auto input = acc.get_tensor(INPUT); @@ -91,8 +87,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { input_grad, output_grad, input.data_type, - attrs.dtype, - handle); + attrs.dtype); } static void backward_task(Task const *task, @@ -111,7 +106,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, auto env = sim.new_environment(); SimTaskBinding fwd_binding; - fwd_binding.bind_arg(HANDLE, ff_handle()); fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(ATTRS, attrs); @@ -137,7 +131,6 @@ OpTaskSignature fwd_signature() { fwd.add_arg_slot(ATTRS); fwd.add_arg_slot(PROFILING); - fwd.add_unchecked_arg_slot(HANDLE); fwd.add_input_slot(INPUT); fwd.add_output_slot(OUTPUT);