Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/runtime/src/legion_backing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include "model.h"
#include "runtime/task_spec/typed_task_invocation.h"
#include "task_spec/concrete_args_format.h"
#include "task_spec/device_specific_arg.h"
#include "task_spec/device_specific.h"
#include "task_spec/future_args_format.h"
#include "task_spec/task_argument_accessor.h"
#include "task_spec/task_invocation_args_format.h"
Expand Down Expand Up @@ -296,7 +296,7 @@ template <>
void register_task<FF_INIT_TASK_ID>() {
TaskSignature sig;
sig.add_arg_slot<FFInitInfo>(FF_INIT_INFO);
sig.add_return_value<DeviceSpecificArg<PerDeviceFFHandle>>();
sig.add_return_value<DeviceSpecific<PerDeviceFFHandle>>();

register_task(FF_INIT_TASK_ID, "cuda init task", sig, ff_init_task);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/runtime/src/ops/aggregate_spec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

#include "aggregate_spec.h"
#include "kernels/aggregate_spec_kernels.h"
#include "task_spec/device_specific_arg.h"
#include "task_spec/device_specific.h"

namespace FlexFlow {

Expand Down
8 changes: 4 additions & 4 deletions lib/runtime/src/ops/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ OpTaskInvocation backward(MultiHeadAttentionAttrs const &attrs) {
return {ATTENTION_BWD_TASK_ID, b};
}

static DeviceSpecificArg<MHAPerDeviceState>
static DeviceSpecific<MHAPerDeviceState>
init_task_impl(TaskArgumentAccessor const &acc) {
auto const &attrs = acc.get_argument<MultiHeadAttentionAttrs>(ATTRS);
Allocator allocator = acc.get_allocator();
Expand Down Expand Up @@ -133,7 +133,7 @@ static DeviceSpecificArg<MHAPerDeviceState>
assert(qoSeqLength == output.shape[legion_dim_t(1)]);
assert(oProjSize == output.shape[legion_dim_t(0)]);

DeviceSpecificArg<MHAPerDeviceState> per_device_state =
DeviceSpecific<MHAPerDeviceState> per_device_state =
acc.create_device_specific<MHAPerDeviceState>(
init_kernel(handle,
allocator,
Expand All @@ -155,7 +155,7 @@ static DeviceSpecificArg<MHAPerDeviceState>
return per_device_state;
}

static DeviceSpecificArg<MHAPerDeviceState>
static DeviceSpecific<MHAPerDeviceState>
init_task(Task const *task,
std::vector<PhysicalRegion> const &regions,
Context ctx,
Expand Down Expand Up @@ -274,7 +274,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim,

auto init_accessor =
env.get_init_accessor(ATTENTION_INIT_TASK_ID, init_binding);
DeviceSpecificArg<MHAPerDeviceState> per_device_state =
DeviceSpecific<MHAPerDeviceState> per_device_state =
init_task_impl(init_accessor);

SimTaskBinding fwd_binding;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
namespace FlexFlow {

template <typename T>
struct DeviceSpecificArg {
struct DeviceSpecific {

DeviceSpecificArg() = delete;
DeviceSpecific() = delete;

template <typename... Args>
static DeviceSpecificArg<T> create(size_t device_idx, Args &&...args) {
static DeviceSpecific<T> create(size_t device_idx, Args &&...args) {
NOT_IMPLEMENTED();
}

T const *get(size_t curr_device_idx) const {
if (curr_device_idx != this->device_idx) {
throw mk_runtime_error("Invalid access to DeviceSpecificArg: attempted "
throw mk_runtime_error("Invalid access to DeviceSpecific: attempted "
"device_idx {} != correct device_idx {})",
curr_device_idx,
this->device_idx);
Expand All @@ -31,10 +31,10 @@ struct DeviceSpecificArg {
size_t device_idx;
};

// manually force serialization to make DeviceSpecificArgs trivially
// manually force serialization to make DeviceSpecific trivially
// serializable
template <typename T>
struct is_trivially_serializable<DeviceSpecificArg<T>> : std::true_type {};
struct is_trivially_serializable<DeviceSpecific<T>> : std::true_type {};

} // namespace FlexFlow

Expand Down
4 changes: 2 additions & 2 deletions lib/runtime/src/task_spec/op_arg_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_ARG_REF_H

#include "arg_ref.h"
#include "device_specific_arg.h"
#include "device_specific.h"
#include "op-attrs/parallel_tensor_shape.h"

namespace FlexFlow {
Expand All @@ -15,7 +15,7 @@ using OpArgRef = ArgRef<OpArgRefType, T>;
using OpArgRefSpec = ArgRefSpec<OpArgRefType>;

template <typename T>
OpArgRef<DeviceSpecificArg<T>> per_device_op_state() {
OpArgRef<DeviceSpecific<T>> per_device_op_state() {
return {OpArgRefType::PER_DEVICE_OP_STATE};
}

Expand Down
4 changes: 2 additions & 2 deletions lib/runtime/src/task_spec/runtime_arg_ref.cc
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include "runtime_arg_ref.h"
#include "device_specific_arg.h"
#include "device_specific.h"

namespace FlexFlow {

RuntimeArgRef<ProfilingSettings> profiling_settings() {
return {RuntimeArgRefType::PROFILING_SETTINGS};
}

RuntimeArgRef<DeviceSpecificArg<PerDeviceFFHandle>> ff_handle() {
RuntimeArgRef<DeviceSpecific<PerDeviceFFHandle>> ff_handle() {
return {RuntimeArgRefType::FF_HANDLE};
}

Expand Down
4 changes: 2 additions & 2 deletions lib/runtime/src/task_spec/runtime_arg_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_RUNTIME_ARG_REF_H

#include "arg_ref.h"
#include "device_specific_arg.h"
#include "device_specific.h"

namespace FlexFlow {

Expand All @@ -14,7 +14,7 @@ using RuntimeArgRef = ArgRef<RuntimeArgRefType, T>;
using RuntimeArgRefSpec = ArgRefSpec<RuntimeArgRefType>;

RuntimeArgRef<ProfilingSettings> profiling_settings();
RuntimeArgRef<DeviceSpecificArg<PerDeviceFFHandle>> ff_handle();
RuntimeArgRef<DeviceSpecific<PerDeviceFFHandle>> ff_handle();

} // namespace FlexFlow

Expand Down
10 changes: 5 additions & 5 deletions lib/runtime/src/task_spec/task_argument_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define _FLEXFLOW_RUNTIME_SRC_TASK_ARGUMENT_ACCESSOR_H

#include "accessor.h"
#include "device_specific_arg.h"
#include "device_specific.h"
#include "realm_allocator.h"
#include "runtime/config.h"
#include "utils/exception.h"
Expand Down Expand Up @@ -166,14 +166,14 @@ struct TaskArgumentAccessor {
}

template <typename T>
T *unwrap(DeviceSpecificArg<T> const &arg) const {
T *unwrap(DeviceSpecific<T> const &arg) const {
return arg.get(this->get_device_idx());
}

template <typename T, typename... Args>
DeviceSpecificArg<T> create_device_specific(Args &&...args) const {
return DeviceSpecificArg<T>::create(this->get_device_idx(),
std::forward<Args>(args)...);
DeviceSpecific<T> create_device_specific(Args &&...args) const {
return DeviceSpecific<T>::create(this->get_device_idx(),
std::forward<Args>(args)...);
}

size_t get_device_idx() const {
Expand Down