diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h index d43446883c..cc10342e75 100644 --- a/lib/kernels/include/kernels/cast_kernels.h +++ b/lib/kernels/include/kernels/cast_kernels.h @@ -3,28 +3,24 @@ #include "kernels/accessor.h" #include "kernels/device.h" -#include "op-attrs/ffconst.h" +#include "kernels/ff_handle.h" +#include "op-attrs/activation.h" namespace FlexFlow { - -class CastPerDeviceState : public PerDeviceOpState { -public: - CastPerDeviceState(FFHandler handle); - DataType input_data_type, output_data_type; -}; - namespace Kernels { namespace Cast { void forward_kernel(ffStream_t stream, - CastPerDeviceState const *, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output); + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type); void backward_kernel(ffStream_t stream, - CastPerDeviceState const *, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output); + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type); } // namespace Cast } // namespace Kernels diff --git a/lib/kernels/src/cuda/cast_kernels.cu b/lib/kernels/src/cuda/cast_kernels.cu index 1afcdb56cc..3d8804862d 100644 --- a/lib/kernels/src/cuda/cast_kernels.cu +++ b/lib/kernels/src/cuda/cast_kernels.cu @@ -13,15 +13,12 @@ * limitations under the License. 
*/ +#include "device.h" #include "kernels/cast_kernels.h" -#include "kernels/cuda_helper.h" #include "kernels/datatype_dispatch.h" +#include "kernels/device.h" namespace FlexFlow { - -CastPerDeviceState::CastPerDeviceState(FFHandler handle) - : PerDeviceOpState(handle) {} - namespace Kernels { namespace Cast { @@ -43,7 +40,6 @@ __global__ void template struct ForwardKernel { void operator()(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -55,7 +51,6 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -65,19 +60,21 @@ struct BackwardKernel { }; void forward_kernel(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output) { + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type) { DataTypeDispatch2{}( - m->input_data_type, m->output_data_type, stream, m, input, output); + input_type, output_type, stream, handle, input, output); } void backward_kernel(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output) { + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type) { DataTypeDispatch2{}( - m->input_data_type, m->output_data_type, stream, m, input, output); + input_type, output_type, stream, handle, input, output); } } // namespace Cast diff --git a/lib/kernels/src/hip/cast_kernels.cpp b/lib/kernels/src/hip/cast_kernels.cpp index 732114335d..cf0ea83275 100644 --- a/lib/kernels/src/hip/cast_kernels.cpp +++ b/lib/kernels/src/hip/cast_kernels.cpp @@ -19,10 +19,6 @@ #include namespace FlexFlow { - -CastPerDeviceState::CastPerDeviceState(FFHandler handle) - : 
PerDeviceOpState(handle) {} - namespace Kernels { namespace Cast { @@ -44,7 +40,6 @@ __global__ void template struct ForwardKernel { void operator()(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -62,7 +57,6 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { size_t volume = input.shape.get_volume(); @@ -79,19 +73,21 @@ struct BackwardKernel { }; void forward_kernel(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output) { + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type) { DataTypeDispatch2{}( - m->input_data_type, m->output_data_type, stream, m, input, output); + input_type, output_type, stream, input, output); } void backward_kernel(ffStream_t stream, - CastPerDeviceState const *m, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output) { + GenericTensorAccessorW const &output, + DataType input_type, + DataType output_type) { DataTypeDispatch2{}( - m->input_data_type, m->output_data_type, stream, m, input, output); + input_type, output_type, stream, input, output); } } // namespace Cast diff --git a/lib/runtime/src/ops/cast.cc b/lib/runtime/src/ops/cast.cc index 23c1bc9940..44230eaf46 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/runtime/src/ops/cast.cc @@ -16,441 +16,149 @@ #include "cast.h" #include "kernels/cast_kernels.h" #include "legion/legion_utilities.h" -#include "task_spec.h" +#include "task_spec/op_task_signature.h" #include "utils/hash-utils.h" using namespace FlexFlow::Kernels::Cast; -namespace FlexFlow { - -enum Slots { - INPUT, - OUTPUT, - INPUT_GRAD, - OUTPUT_GRAD, - ATTRS, - PROFILING -} - -// declare Legion names -using Legion::ArgumentMap; 
using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; - -Tensor FFModel::cast(const Tensor input, DataType dtype, char const *name) { - Layer *cast = new Layer(this, - OP_CAST, - dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input); - int numdims = input->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = input->dims[i]; - } - cast->outputs[0] = create_tensor_legion_ordering( - numdims, dims, dtype, cast, 0, true /*create_grad*/); - cast->add_int_property("dtype", dtype); - layers.push_back(cast); - return cast->outputs[0]; -} - -Op *Cast::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("dtype", value); - DataType dtype = (DataType)value; - return new Cast(model, inputs[0], dtype, layer->name); -} -CastParams Cast::get_params() const { - CastParams params; - params.dtype = this->outputs[0]->data_type; - return params; -} - -Cast::Cast(FFModel &model, - ParallelTensor const &input, - DataType _dtype, - char const *name) - : Op(model, - OP_CAST, - _dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input) { - numOutputs = 1; - numWeights = 0; - int numdim = input->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = input->dims[i]; - } - outputs[0] = - model.create_parallel_tensor_legion_ordering(numdim, dims, _dtype, this); -} +namespace FlexFlow { -Cast::Cast(FFModel &model, - CastParams const ¶ms, - ParallelTensor const &input, - char const *name) - : Cast(model, input, params.dtype, name) {} +enum Slots { INPUT, OUTPUT, ATTRS, PROFILING }; -static OpTaskSignature 
get_init_task_signature() { - OpTaskSignature init(OpTaskType::INIT); +OpTaskInvocation forward(CastAttrs const &attrs) { + OpTaskBinding binding; - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(ATTRS, attrs); - init.add_input_slot(INPUT); - init.add_output_slot(OUTPUT); + binding.bind(INPUT, input_tensor(0)); + binding.bind(OUTPUT, output_tensor(0)); - return init; + return {CAST_FWD_TASK_ID, binding}; } -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); +OpTaskInvocation backward(CastAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return init; + return {CAST_BWD_TASK_ID, binding}; } -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto const &attrs = acc.get_argument(ATTRS); - bwd.add_arg_slot(ATTRS); + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); - bwd.add_input_grad_slot(INPUT_GRAD); - bwd.add_output_grad_slot(OUTPUT_GRAD); + return profile(forward_kernel, + profiling, + "[Cast] forward_time = %.2lfms\n", + input, + output, + input.data_type, + attrs.dtype); +} - return bwd; +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -OpTaskBinding Cast::get_init_task_binding() const { - OpTaskBinding binding; +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto const &attrs = acc.get_argument(ATTRS); - binding.bind_arg(PROFILING, this->profiling); - binding.bind_arg(ATTRS, this->attrs); + 
auto input = acc.get_tensor(INPUT); - binding.bind(INPUT, input_tensor(0)); - binding.bind(OUTPUT, output_tensor(0)); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); - return binding; + return profile(backward_kernel, + profiling, + "[Cast] backward_time = %.2lfms\n", + input_grad, + output_grad, + input.data_type, + attrs.dtype); } -OpTaskBinding Cast::get_fwd_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); - - binding.bind(INPUT, input_tensor(0)); - binding.bind(OUTPUT, output_tensor(0)); - - return binding; +static void backward_task(Task const *task, + std::vector const &regions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -OpTaskBinding Cast::get_bwd_task_binding() const { - OpTaskBinding binding; +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + CastAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); - binding.bind_arg(ATTRS, this->attrs); + SimTaskBinding fwd_binding; + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(ATTRS, attrs); - binding.bind(INPUT_GRAD, input_tensor(0).grad()); - binding.bind(OUTPUT_GRAD, output_tensor(0).grad()); + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(OUTPUT, input_shape); // cast does not change shape - return binding; -} + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); -void Cast::init(FFModel const &ff) { - this->execute_task(ff, CAST_INIT_TASK_ID, get_init_task_signature()); - // assert(check_output_input_weight_same_parallel_is()); - // parallel_is = outputs[0]->parallel_is; - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_init(ff, argmap); - // IndexLauncher launcher(CAST_INIT_TASK_ID, - // parallel_is, - // 
TaskArgument(this, sizeof(Cast)), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // FutureMap fm = runtime->execute_index_space(ctx, launcher); - // fm.wait_all_results(); - // set_opmeta_from_futuremap(ff, fm); -} + auto fwd_accessor = env.get_fwd_accessor(CAST_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(CAST_BWD_TASK_ID, bwd_binding); -PerDeviceOpState *Cast::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); - FFHandler handler = *((FFHandler const *)task->local_args); - CastPerDeviceState *m = new CastPerDeviceState(handler); - bool profiling = acc.get_argument(PROFILING); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - m->input_data_type = input->data_type; - m->output_data_type = output->data_type; - m->profiling = profiling; - return m; -} + float sync_time = default_estimate_sync_time(env); -void Cast::forward(FFModel const &ff) { - this->execute_task(ff, CAST_FWD_TASK_ID, get_fwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_forward(ff, argmap); - // IndexLauncher launcher(CAST_FWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, false), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 
/*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // inputs[0]->region)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[0]->region)); - // launcher.add_field(1, FID_DATA); - // runtime->execute_index_space(ctx, launcher); + return make_metrics(forward_time, backward_time, sync_time, env); } -// template -// void Cast::forward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->output_data_type == DT_FLOAT) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_DOUBLE) { -// Cast::forward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->output_data_type == DT_INT32) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->output_data_type == DT_INT64) { -// Cast::forward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::forward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT *input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerWO( -// regions[1], task->regions[1], FID_DATA, 
ctx, runtime); -// forward_kernel_wrapper( -// m, input_ptr, output_ptr, output_domain.get_volume()); -// } - -void Cast::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); - // if (m->input_data_type == DT_FLOAT) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_DOUBLE) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_INT32) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->input_data_type == DT_INT64) { - // Cast::forward_task_with_1_type(task, regions, ctx, runtime); - // } - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - profile(forward_kernel, - m->profiling, - "[Cast] forward_time = %.2lfms\n", - m, - input, - output) -} +template <> +OpTaskSignature fwd_signature() { + OpTaskSignature fwd(OpTaskType::FWD); -void Cast::backward(FFModel const &ff) { - this->execute_task(ff, CAST_BWD_TASK_ID, get_bwd_task_signature()); - // ArgumentMap argmap; - // Context ctx = ff.config.lg_ctx; - // Runtime *runtime = ff.config.lg_hlr; - // set_argumentmap_for_backward(ff, argmap); - // IndexLauncher launcher(CAST_BWD_TASK_ID, - // parallel_is, - // TaskArgument(NULL, false), - // argmap, - // Predicate::TRUE_PRED, - // false /*must*/, - // 0 /*mapper_id*/, - // outputs[0]->machine_view.hash()); - // launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - // 0 /*projection id*/, - // READ_ONLY, - // EXCLUSIVE, - // outputs[0]->region_grad)); - // launcher.add_field(0, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // inputs[0]->region_grad)); - // launcher.add_field(1, FID_DATA); - // runtime->execute_index_space(ctx, launcher); -} + 
fwd.add_arg_slot(ATTRS); + fwd.add_arg_slot(PROFILING); -// template -// void Cast::backward_task_with_1_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); -// if (m->input_data_type == DT_FLOAT) { -// Cast::backward_task_with_2_type(task, regions, ctx, runtime); -// } else if (m->input_data_type == DT_DOUBLE) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT32) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } else if (m->input_data_type == DT_INT64) { -// Cast::backward_task_with_2_type(task, regions, ctx, -// runtime); -// } -// } - -// template -// void Cast::backward_task_with_2_type(Task const *task, -// std::vector const -// ®ions, Context ctx, Runtime *runtime) -// { -// assert(regions.size() == 2); -// assert(task->regions.size() == regions.size()); -// // Domain input_domain = runtime->get_index_space_domain( -// // ctx, task->regions[0].region.get_index_space()); -// Domain output_domain = runtime->get_index_space_domain( -// ctx, task->regions[1].region.get_index_space()); -// const IDT *input_ptr = helperGetTensorPointerRO( -// regions[0], task->regions[0], FID_DATA, ctx, runtime); -// ODT *output_ptr = helperGetTensorPointerRW( -// regions[1], task->regions[1], FID_DATA, ctx, runtime); -// backward_kernel_wrapper( -// input_ptr, output_ptr, output_domain.get_volume()); -// } - -void Cast::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - CastPerDeviceState const *m = *((CastPerDeviceState **)task->local_args); - // if (m->output_data_type == DT_FLOAT) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_DOUBLE) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_INT32) { - // 
Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } else if (m->output_data_type == DT_INT64) { - // Cast::backward_task_with_1_type(task, regions, ctx, runtime); - // } - auto input_grad = acc.get_tensor(INPUT); - auto output_grad = acc.get_tensor(OUTPUT); - - profile(backward_kernel, - m->profiling, - "[Cast] forward_time = %.2lfms\n", - m, - input_grad, - output_grad) -} + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); -bool Cast::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - // Assume cast has no cost - cost_metrics.forward_time = 0.0f; - cost_metrics.backward_time = 0.0f; - cost_metrics.inputs_memory = 0; - cost_metrics.outputs_memory = 0; - cost_metrics.weights_memory = 0; - return true; + return fwd; } -void Cast::serialize(Legion::Serializer &sez) const { - sez.serialize(this->outputs[0]->data_type); +template <> +void register_task() { + register_task(CAST_FWD_TASK_ID, + "Cast Fwd", + fwd_signature(), + forward_task); } -using PCG::Node; +template <> +OpTaskSignature bwd_signature() { + OpTaskSignature bwd = infer_bwd_signature(fwd_signature()); -Node Cast::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 1); - DataType dtype; - dez.deserialize(dtype); - return ff.get_or_create_node(inputs[0], {dtype}); + return bwd; } -Op *Cast::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - assert(num_inputs == 1); - return new Cast(ff, inputs[0], this->outputs[0]->data_type, this->name); +template <> +void register_task() { + register_task(CAST_BWD_TASK_ID, + "Cast Bwd", + bwd_signature(), + backward_task); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/cast.h b/lib/runtime/src/ops/cast.h index 7d346584d5..c0c500e869 100644 --- a/lib/runtime/src/ops/cast.h +++ b/lib/runtime/src/ops/cast.h @@ -16,8 +16,8 @@ #define _FLEXFLOW_CAST_H #include "op-attrs/ops/cast.h" 
-#include "op_task_invocation.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -32,53 +32,11 @@ OpTaskInvocation init(CastAttrs const &); OpTaskInvocation forward(CastAttrs const &); OpTaskInvocation backward(CastAttrs const &); -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - BatchNormAttrs const &attrs, - ParallelTensorShape const &input_shape, +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + CastAttrs const &attrs, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, - MachineView const &machine_view); - -/* class Cast : public Op { */ -/* public: */ -/* Cast(FFModel &model, */ -/* ParallelTensor const &input, */ -/* DataType dtype, */ -/* char const *name); */ -/* Cast(FFModel &model, */ -/* CastAttrs const ¶ms, */ -/* std::vector const &input, */ -/* char const *name = nullptr); */ -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* OpTaskBinding get_init_task_binding() const override; */ -/* OpTaskBinding get_fwd_task_binding() const override; */ -/* OpTaskBinding get_bwd_task_binding() const override; */ - -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const; */ -/* }; */ + MachineView const 
&mv); } // namespace FlexFlow