diff --git a/lib/runtime/src/legion_backing.cc b/lib/runtime/src/legion_backing.cc index 4e24a43db1..51cbf82242 100644 --- a/lib/runtime/src/legion_backing.cc +++ b/lib/runtime/src/legion_backing.cc @@ -3,7 +3,7 @@ #include "model.h" #include "runtime/task_spec/typed_task_invocation.h" #include "task_spec/concrete_args_format.h" -#include "task_spec/device_specific_arg.h" +#include "task_spec/device_specific.h" #include "task_spec/future_args_format.h" #include "task_spec/task_argument_accessor.h" #include "task_spec/task_invocation_args_format.h" @@ -296,7 +296,7 @@ template <> void register_task() { TaskSignature sig; sig.add_arg_slot(FF_INIT_INFO); - sig.add_return_value>(); + sig.add_return_value>(); register_task(FF_INIT_TASK_ID, "cuda init task", sig, ff_init_task); } diff --git a/lib/runtime/src/ops/aggregate_spec.cc b/lib/runtime/src/ops/aggregate_spec.cc index f8cd7db067..d720d5312b 100644 --- a/lib/runtime/src/ops/aggregate_spec.cc +++ b/lib/runtime/src/ops/aggregate_spec.cc @@ -15,7 +15,7 @@ #include "aggregate_spec.h" #include "kernels/aggregate_spec_kernels.h" -#include "task_spec/device_specific_arg.h" +#include "task_spec/device_specific.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/attention.cc b/lib/runtime/src/ops/attention.cc index 6a3a0bd300..bca87bdb53 100644 --- a/lib/runtime/src/ops/attention.cc +++ b/lib/runtime/src/ops/attention.cc @@ -85,7 +85,7 @@ OpTaskInvocation backward(MultiHeadAttentionAttrs const &attrs) { return {ATTENTION_BWD_TASK_ID, b}; } -static DeviceSpecificArg +static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); Allocator allocator = acc.get_allocator(); @@ -133,7 +133,7 @@ static DeviceSpecificArg assert(qoSeqLength == output.shape[legion_dim_t(1)]); assert(oProjSize == output.shape[legion_dim_t(0)]); - DeviceSpecificArg per_device_state = + DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, allocator, @@ -155,7 +155,7 @@ static DeviceSpecificArg return per_device_state; } -static DeviceSpecificArg +static DeviceSpecific init_task(Task const *task, std::vector const ®ions, Context ctx, @@ -274,7 +274,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, auto init_accessor = env.get_init_accessor(ATTENTION_INIT_TASK_ID, init_binding); - DeviceSpecificArg per_device_state = + DeviceSpecific per_device_state = init_task_impl(init_accessor); SimTaskBinding fwd_binding; diff --git a/lib/runtime/src/task_spec/device_specific_arg.h b/lib/runtime/src/task_spec/device_specific.h similarity index 65% rename from lib/runtime/src/task_spec/device_specific_arg.h rename to lib/runtime/src/task_spec/device_specific.h index 5331bd14b8..e29e4e9450 100644 --- a/lib/runtime/src/task_spec/device_specific_arg.h +++ b/lib/runtime/src/task_spec/device_specific.h @@ -7,18 +7,18 @@ namespace FlexFlow { template -struct DeviceSpecificArg { +struct DeviceSpecific { - DeviceSpecificArg() = delete; + DeviceSpecific() = delete; template - static DeviceSpecificArg create(size_t device_idx, Args &&...args) { + static DeviceSpecific create(size_t device_idx, Args &&...args) { NOT_IMPLEMENTED(); } T const *get(size_t curr_device_idx) const { if (curr_device_idx != this->device_idx) { - throw mk_runtime_error("Invalid access to DeviceSpecificArg: attempted " + throw mk_runtime_error("Invalid access to DeviceSpecific: attempted " "device_idx {} != correct device_idx {})", curr_device_idx, this->device_idx); @@ -31,10 +31,10 @@ struct DeviceSpecificArg { size_t device_idx; }; -// manually force serialization to make DeviceSpecificArgs trivially +// manually force serialization to make DeviceSpecific trivially // serializable template -struct is_trivially_serializable> : std::true_type {}; +struct is_trivially_serializable> : std::true_type {}; } // namespace FlexFlow diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/runtime/src/task_spec/op_arg_ref.h index 6e921c05e8..3e931d79a4 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/runtime/src/task_spec/op_arg_ref.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_ARG_REF_H #include "arg_ref.h" -#include "device_specific_arg.h" +#include "device_specific.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { @@ -15,7 +15,7 @@ using OpArgRef = ArgRef; using OpArgRefSpec = ArgRefSpec; template -OpArgRef> per_device_op_state() { +OpArgRef> per_device_op_state() { return {OpArgRefType::PER_DEVICE_OP_STATE}; } diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.cc b/lib/runtime/src/task_spec/runtime_arg_ref.cc index bb516849cd..a0aa242ce6 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.cc +++ b/lib/runtime/src/task_spec/runtime_arg_ref.cc @@ -1,5 +1,5 @@ #include "runtime_arg_ref.h" -#include "device_specific_arg.h" +#include "device_specific.h" namespace FlexFlow { @@ -7,7 +7,7 @@ RuntimeArgRef profiling_settings() { return {RuntimeArgRefType::PROFILING_SETTINGS}; } -RuntimeArgRef> ff_handle() { +RuntimeArgRef> ff_handle() { return {RuntimeArgRefType::FF_HANDLE}; } diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.h b/lib/runtime/src/task_spec/runtime_arg_ref.h index 1874d39584..6b4345091a 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.h +++ b/lib/runtime/src/task_spec/runtime_arg_ref.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_RUNTIME_ARG_REF_H #include "arg_ref.h" -#include "device_specific_arg.h" +#include "device_specific.h" namespace FlexFlow { @@ -14,7 +14,7 @@ using RuntimeArgRef = ArgRef; using RuntimeArgRefSpec = ArgRefSpec; RuntimeArgRef profiling_settings(); -RuntimeArgRef> ff_handle(); +RuntimeArgRef> ff_handle(); } // namespace FlexFlow diff --git a/lib/runtime/src/task_spec/task_argument_accessor.h b/lib/runtime/src/task_spec/task_argument_accessor.h index 89b0c48cbd..9cc05b8252 100644 --- a/lib/runtime/src/task_spec/task_argument_accessor.h +++ b/lib/runtime/src/task_spec/task_argument_accessor.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_RUNTIME_SRC_TASK_ARGUMENT_ACCESSOR_H #include "accessor.h" -#include "device_specific_arg.h" +#include "device_specific.h" #include "realm_allocator.h" #include "runtime/config.h" #include "utils/exception.h" @@ -166,14 +166,14 @@ struct TaskArgumentAccessor { } template - T *unwrap(DeviceSpecificArg const &arg) const { + T *unwrap(DeviceSpecific const &arg) const { return arg.get(this->get_device_idx()); } template - DeviceSpecificArg create_device_specific(Args &&...args) const { - return DeviceSpecificArg::create(this->get_device_idx(), - std::forward(args)...); + DeviceSpecific create_device_specific(Args &&...args) const { + return DeviceSpecific::create(this->get_device_idx(), + std::forward(args)...); } size_t get_device_idx() const {