Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
9d68a63
Initial work on exposing copies to dynamic graph.
elliottslaughter Feb 26, 2026
b8af3ab
And now with actually fixed tests.
elliottslaughter Feb 26, 2026
4b8411e
More work on copy insertion.
elliottslaughter Feb 26, 2026
76afe66
Don't test mapping in machine_slicing.
elliottslaughter Feb 26, 2026
efc8d56
Basic test for no copies.
elliottslaughter Feb 26, 2026
f43019f
Test copy case.
elliottslaughter Feb 26, 2026
890bf54
Check no copies pre-exist copy insertion.
elliottslaughter Feb 26, 2026
9a1479c
Filter to avoid degenerate copies.
elliottslaughter Feb 26, 2026
49fd444
Wire up copy insertion and fix shard expansion.
elliottslaughter Feb 27, 2026
878b822
Sketch interface for issuing operations.
elliottslaughter Feb 27, 2026
b4fcf43
Sketch interface for copies.
elliottslaughter Feb 27, 2026
9d06077
Implement copies.
elliottslaughter Feb 27, 2026
4560087
Assign copies to a phase based on tensor roles.
elliottslaughter Feb 27, 2026
fe37fc9
Update shard expansion test to include copy case.
elliottslaughter Feb 27, 2026
fdb4ebe
It is safe to return NO_EVENT for nop tasks even in presence of dependencies
elliottslaughter Feb 27, 2026
fbb9eba
Update to match Realm PR changes.
elliottslaughter Mar 18, 2026
8778da4
Merge branch 'master' into realm-data-movement-explicit
lockshaw Mar 19, 2026
5304a6b
Updates in response to feedback.
elliottslaughter Mar 19, 2026
ca4de4d
Respond to PR feedback.
elliottslaughter Mar 20, 2026
d47668a
Fixes from PR review
lockshaw Mar 20, 2026
679b25f
Format
lockshaw Mar 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_OPERATOR_TASK_GROUP_H

#include "op-attrs/computation_graph_op_attrs.dtg.h"
#include "op-attrs/tensor_slot_name.dtg.h"
#include "pcg/machine_space_coordinate.dtg.h"
#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h"
#include "utils/bidict/bidict.h"
Expand Down Expand Up @@ -32,6 +33,10 @@ struct MappedOperatorTaskGroup {
friend struct ::std::hash<MappedOperatorTaskGroup>;
};

/**
 * \brief For the given tensor slot, returns the bidirectional mapping from
 * each shard's parallel-tensor-space coordinate (for that slot) to the
 * machine-space coordinate the shard is bound to.
 */
bidict<ParallelTensorSpaceCoordinate, MachineSpaceCoordinate>
    get_tensor_bindings_for_slot_name(MappedOperatorTaskGroup const &,
                                      TensorSlotName const &);

std::string format_as(::FlexFlow::MappedOperatorTaskGroup const &);
std::ostream &operator<<(std::ostream &,
                         ::FlexFlow::MappedOperatorTaskGroup const &);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "op-attrs/operator_task_space.h"
#include "op-attrs/parallel_tensor_space_coordinate.h"
#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.h"
#include "utils/bidict/algorithms/transform_values.h"
#include "utils/bidict/generate_bidict.h"
#include "utils/containers/are_all_distinct.h"
#include "utils/containers/require_all_same.h"
Expand Down Expand Up @@ -70,6 +71,17 @@ bidict<MachineSpaceCoordinate, OperatorAtomicTaskShardBinding> const &
return this->shard_bindings;
}

/**
 * \brief For a single tensor slot, relate each shard's parallel-tensor-space
 * coordinate (for that slot) to the machine-space coordinate the shard is
 * bound to.
 */
bidict<ParallelTensorSpaceCoordinate, MachineSpaceCoordinate>
    get_tensor_bindings_for_slot_name(MappedOperatorTaskGroup const &task_group,
                                      TensorSlotName const &slot_name) {
  auto coord_for_binding = [&](OperatorAtomicTaskShardBinding const &binding) {
    return ptensor_space_coord_for_slot_name(binding, slot_name);
  };
  // transform_values yields machine-coord -> tensor-coord; callers want the
  // opposite orientation, so reverse the bidict before returning.
  bidict<MachineSpaceCoordinate, ParallelTensorSpaceCoordinate>
      machine_to_tensor = transform_values(task_group.get_shard_bindings(),
                                           coord_for_binding);
  return machine_to_tensor.reversed();
}

std::string format_as(::FlexFlow::MappedOperatorTaskGroup const &m) {
return fmt::format("<MappedOperatorTaskGroup shard_bindings={}>",
m.get_shard_bindings());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ features = []
docstring = '''
\brief Maps each shard-expanded DynamicNodeInvocation to its corresponding PerDeviceOpState.

PerDeviceOpStateBacking is to PerDeviceOpState as DistributedDeviceHandle is to \ref device_handle_t (i.e., FFHandle).
\ref PerDeviceOpStateBacking is to \ref PerDeviceOpState as \ref DistributedFfHandle is to \ref device_handle_t (i.e., FFHandle).
'''


Expand Down
12 changes: 12 additions & 0 deletions lib/realm-execution/include/realm-execution/realm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "kernels/allocation.h"
#include "kernels/device_handle_t.dtg.h"
#include "kernels/managed_per_device_ff_handle.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
#include "op-attrs/tensor_shape.dtg.h"
#include "pcg/device_id_t.dtg.h"
#include "pcg/machine_space_coordinate.dtg.h"
Expand Down Expand Up @@ -62,6 +63,17 @@ struct RealmContext {
int priority = 0);
///\}

/** \name Data movement */
///\{
Realm::Event issue_copy(ParallelTensorShape const &src_shape,
Realm::RegionInstance src_inst,
ParallelTensorShape const &dst_shape,
Realm::RegionInstance dst_inst,
Realm::ProfilingRequestSet const &requests,
Realm::Event wait_on = Realm::Event::NO_EVENT,
int priority = 0);
///\}

/** \name Instance management */
///\{
std::pair<Realm::RegionInstance, Realm::Event>
Expand Down
111 changes: 84 additions & 27 deletions lib/realm-execution/src/realm-execution/pcg_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "realm-execution/realm_context.h"
#include "realm-execution/tasks/impl/op_task.h"
#include "realm-execution/tensor_instance_backing.h"
#include "task-spec/dynamic_graph/copy_insertion.h"
#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h"
#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h"
#include "task-spec/dynamic_graph/dynamic_task_type.dtg.h"
Expand All @@ -18,6 +19,7 @@
#include "task-spec/dynamic_graph/shard_expansion.h"
#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h"
#include "task-spec/dynamic_graph/update_insertion.h"
#include "utils/containers/get_only.h"
#include "utils/containers/map_values.h"
#include "utils/containers/transform.h"
#include "utils/containers/try_at.h"
Expand Down Expand Up @@ -104,6 +106,7 @@ PCGInstance create_pcg_instance(
}

dg = perform_update_insertion(dg, optimizer_attrs);
dg = perform_copy_insertion(dg);
dg = perform_shard_expansion(dg);
TensorInstanceBacking tensor_instance_backing =
perform_instance_allocation(dg, inputs, ctx);
Expand Down Expand Up @@ -157,6 +160,76 @@ PCGInstance create_pcg_instance(
/*logit_grad_tensor=*/logit_grad_tensor};
}

/**
 * \brief Spawn the Realm operations (tasks, copies, etc.) for a given \ref
 * DynamicNodeInvocation, given the specified dependencies, instances, etc. Note
 * that one \ref DynamicNodeInvocation may become multiple Realm operations
 * (e.g., a parallel operator may turn into multiple copies).
 *
 * \param input_dependencies Events guarding this invocation's input tensors.
 * \param output_dependencies Events guarding this invocation's output tensors.
 * \return Event that triggers when the spawned operation(s) complete, or
 *         Realm::Event::NO_EVENT for nodes that perform no runtime work
 *         (inputs and weights).
 */
static Realm::Event spawn_dynamic_node_invocation(
    RealmContext &ctx,
    DynamicNodeInvocation const &invocation,
    std::vector<Realm::Event> const &input_dependencies,
    std::vector<Realm::Event> const &output_dependencies,
    TensorInstanceBacking const &tensor_instance_backing,
    PerDeviceOpStateBacking const &device_state_backing,
    OptimizerAttrs const &optimizer_attrs,
    ProfilingSettings const &profiling_settings,
    DistributedFfHandle const &device_handle,
    FFIterationConfig iteration_config) {
  // Every spawned operation waits on both the reader (input) and writer
  // (output) dependencies, merged into a single precondition event.
  Realm::Event precondition = Realm::Event::merge_events(
      Realm::Event::merge_events(input_dependencies),
      Realm::Event::merge_events(output_dependencies));

  // Restrict the instance backing to only the tensors this invocation touches.
  TensorInstanceBacking tensor_backing =
      subset_tensor_instance_backing_for_invocation(tensor_instance_backing,
                                                    invocation);

  auto spawn_task = [&]() {
    Realm::Processor target_proc = ctx.map_device_coord_to_processor(
        assert_unwrap(invocation.node_attrs.device_coord));
    return spawn_op_task(ctx,
                         target_proc,
                         invocation,
                         tensor_backing,
                         try_at(device_state_backing.backing, invocation),
                         profiling_settings,
                         device_handle.at(target_proc),
                         iteration_config,
                         optimizer_attrs,
                         precondition);
  };

  // Copy nodes are expected to have exactly one input and one output tensor;
  // get_only enforces this at runtime.
  auto issue_copy = [&]() {
    DynamicValueAttrs const &input = get_only(invocation.inputs).second;
    DynamicValueAttrs const &output = get_only(invocation.outputs).second;
    Realm::RegionInstance src_inst =
        tensor_instance_backing.backing.at(input).first;
    Realm::RegionInstance dst_inst =
        tensor_instance_backing.backing.at(output).first;
    return ctx.issue_copy(assert_unwrap(input.parallel_tensor_shape),
                          src_inst,
                          assert_unwrap(output.parallel_tensor_shape),
                          dst_inst,
                          Realm::ProfilingRequestSet{},
                          precondition);
  };

  TrainingOperationAttrs op_attrs =
      assert_unwrap(invocation.node_attrs.op_attrs);
  return op_attrs.visit<Realm::Event>(overload{
      [&](PCGOperatorAttrs const &pcg_op_attrs) {
        return pcg_op_attrs.visit<Realm::Event>(overload{
            // Input/weight nodes spawn no work, so no event is produced.
            // NOTE(review): presumably safe even with pending dependencies
            // (per commit history) -- confirm against dependency tracking.
            [&](InputAttrs const &) { return Realm::Event::NO_EVENT; },
            [&](WeightAttrs const &) { return Realm::Event::NO_EVENT; },
            [&](auto const &) { return spawn_task(); },
        });
      },
      [&](LossAttrs const &) { return spawn_task(); },
      [&](CopyAttrs const &) { return issue_copy(); },
  });
}

static std::unordered_map<dynamic_layer_guid_t, Realm::Event>
execute_distributed_dynamic_node_invocation_set(
RealmContext &ctx,
Expand All @@ -172,14 +245,6 @@ static std::unordered_map<dynamic_layer_guid_t, Realm::Event>
DependencySet dependency_set{ctx.get_outstanding_events()};
return unordered_map_from_pairs(
transform(invocations, [&](DynamicNodeInvocation const &invocation) {
TrainingOperationAttrs op_attrs =
assert_unwrap(invocation.node_attrs.op_attrs);
if (op_attrs.is_pcg_op() && (op_attrs.require_pcg_op().is_input() ||
op_attrs.require_pcg_op().is_weight())) {
return std::pair{invocation.node_attrs.layer_guid,
Realm::Event::NO_EVENT};
}

std::vector<Realm::Event> input_dependencies =
transform(vector_of(values(invocation.inputs)),
[&](DynamicValueAttrs const &value) {
Expand All @@ -190,27 +255,19 @@ static std::unordered_map<dynamic_layer_guid_t, Realm::Event>
[&](DynamicValueAttrs const &value) {
return dependency_set.get_dependency_for_writer(value);
});
Realm::Event dependencies = Realm::Event::merge_events(
Realm::Event::merge_events(input_dependencies),
Realm::Event::merge_events(output_dependencies));
Realm::Processor target_proc = ctx.map_device_coord_to_processor(
assert_unwrap(invocation.node_attrs.device_coord));

TensorInstanceBacking tensor_backing =
subset_tensor_instance_backing_for_invocation(
tensor_instance_backing, invocation);

Realm::Event result =
spawn_op_task(ctx,
target_proc,
invocation,
tensor_backing,
try_at(device_state_backing.backing, invocation),
profiling_settings,
device_handle.at(target_proc),
iteration_config,
optimizer_attrs,
dependencies);
spawn_dynamic_node_invocation(ctx,
invocation,
input_dependencies,
output_dependencies,
tensor_instance_backing,
device_state_backing,
optimizer_attrs,
profiling_settings,
device_handle,
iteration_config);

for (DynamicValueAttrs const &value : values(invocation.inputs)) {
dependency_set.add_reader(value, result);
}
Expand Down
82 changes: 80 additions & 2 deletions lib/realm-execution/src/realm-execution/realm_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "kernels/device_handle_t.dtg.h"
#include "kernels/device_handle_t.h"
#include "op-attrs/datatype.h"
#include "op-attrs/parallel_tensor_shape.h"
#include "op-attrs/tensor_dims.dtg.h"
#include "pcg/device_id_t.h"
#include "pcg/device_type.dtg.h"
Expand All @@ -10,6 +11,7 @@
#include "realm-execution/tasks/task_id_t.h"
#include "utils/containers/contains_key.h"
#include "utils/containers/transform.h"
#include "utils/exception.h"
#include "utils/nonnegative_int/nonnegative_int.h"
#include "utils/one_to_many/one_to_many.h"
#include "utils/positive_int/positive_int.h"
Expand Down Expand Up @@ -146,13 +148,89 @@ static Realm::Rect<N, T> rect_from_dims(TensorDims const &dims) {
Realm::Point<N, T>::ONES()};
}

/// Build a dense N-dimensional Realm index space covering the given dims.
template <int N, typename T = int>
static Realm::IndexSpace<N, T> ispace_from_dims(TensorDims const &dims) {
  return Realm::IndexSpace<N, T>{rect_from_dims<N, T>(dims)};
}

/**
 * \brief Issue a Realm copy of one tensor piece from \p src_inst to
 * \p dst_inst, covering the per-shard (piece) shape of the given parallel
 * tensor shapes.
 *
 * \param requests Profiling requests attached to the copy operation.
 * \param wait_on  Event the copy must wait on before starting.
 * \param priority Realm priority for the copy operation.
 * \return Event that triggers when the copy completes; also recorded in
 *         this->outstanding_events.
 */
Realm::Event
    RealmContext::issue_copy(ParallelTensorShape const &src_shape,
                             Realm::RegionInstance src_inst,
                             ParallelTensorShape const &dst_shape,
                             Realm::RegionInstance dst_inst,
                             Realm::ProfilingRequestSet const &requests,
                             Realm::Event wait_on,
                             int priority) {
  TensorShape src_piece_shape = get_piece_shape(src_shape);
  TensorShape dst_piece_shape = get_piece_shape(dst_shape);
  ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match

  // Both instances hold a single field (field id 0) whose size is the
  // tensor's element size.
  auto make_field = [](Realm::RegionInstance inst,
                       TensorShape const &piece_shape) {
    Realm::CopySrcDstField field;
    field.set_field(
        /*inst=*/inst,
        /*field_id=*/0,
        /*size=*/
        static_cast<size_t>(
            size_of_datatype(piece_shape.data_type).int_from_positive_int()),
        /*subfield_offset=*/0);
    return field;
  };
  Realm::CopySrcDstField src_field = make_field(src_inst, src_piece_shape);
  // Fix: size the destination field from the destination's datatype. The
  // original used src_piece_shape for both; equivalent today because of the
  // ASSERT above, but wrong once mismatched piece shapes are supported.
  Realm::CopySrcDstField dst_field = make_field(dst_inst, dst_piece_shape);

  // Realm index spaces are templated on dimensionality, so dispatch on the
  // runtime rank of the piece shape.
  Realm::Event result;
  switch (src_piece_shape.dims.ff_ordered.num_dims()) {
#if REALM_MAX_DIM >= 1
    case 1:
      result = ispace_from_dims<1>(src_piece_shape.dims)
                   .copy({src_field}, {dst_field}, requests, wait_on, priority);
      break;
#endif
#if REALM_MAX_DIM >= 2
    case 2:
      result = ispace_from_dims<2>(src_piece_shape.dims)
                   .copy({src_field}, {dst_field}, requests, wait_on, priority);
      break;
#endif
#if REALM_MAX_DIM >= 3
    case 3:
      result = ispace_from_dims<3>(src_piece_shape.dims)
                   .copy({src_field}, {dst_field}, requests, wait_on, priority);
      break;
#endif
#if REALM_MAX_DIM >= 4
    case 4:
      result = ispace_from_dims<4>(src_piece_shape.dims)
                   .copy({src_field}, {dst_field}, requests, wait_on, priority);
      break;
#endif
#if REALM_MAX_DIM >= 5
    case 5:
      result = ispace_from_dims<5>(src_piece_shape.dims)
                   .copy({src_field}, {dst_field}, requests, wait_on, priority);
      break;
#endif
    default:
      PANIC("TensorShape dims greater than REALM_MAX_DIM: {}",
            src_piece_shape.dims.ff_ordered.num_dims());
      break;
  }
  this->outstanding_events.push_back(result);
  return result;
}

std::pair<Realm::RegionInstance, Realm::Event>
RealmContext::create_instance(Realm::Memory memory,
TensorShape const &shape,
Realm::ProfilingRequestSet const &prs,
Realm::Event wait_on) {
std::vector<size_t> field_sizes{
static_cast<size_t>(int{size_of_datatype(shape.data_type)})};
std::vector<size_t> field_sizes{static_cast<size_t>(
size_of_datatype(shape.data_type).int_from_positive_int())};
Realm::RegionInstance inst;
Realm::Event ready;
switch (shape.dims.ff_ordered.num_dims()) {
Expand Down
13 changes: 13 additions & 0 deletions lib/task-spec/include/task-spec/dynamic_graph/copy_attrs.dtg.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Attributes for a copy node inserted by copy insertion. Carries no fields:
# the copy is fully determined by its input and output tensors.
namespace = "FlexFlow"
name = "CopyAttrs"
type = "struct"
features = [
  "eq",
  "ord",
  "hash",
  "json",
  "fmt",
  "rapidcheck",
]

fields = []
26 changes: 26 additions & 0 deletions lib/task-spec/include/task-spec/dynamic_graph/copy_insertion.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_COPY_INSERTION_H
#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_COPY_INSERTION_H

#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h"
#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h"
#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h"

namespace FlexFlow {

bool node_is_copy(DynamicNodeAttrs const &n);
bool value_is_mapped(DynamicValueAttrs const &);

bool no_part_of_graph_is_copy_inserted(DynamicOpenDataflowGraph const &);
bool graph_is_fully_copy_inserted(DynamicOpenDataflowGraph const &);

std::unordered_set<DynamicNodeInvocation> perform_copy_insertion_for_invocation(
DynamicNodeInvocation const &i,
std::unordered_map<DynamicValueAttrs, DynamicValueAttrs> const
&unmapped_value_to_mapped_source_value);

DynamicOpenDataflowGraph
perform_copy_insertion(DynamicOpenDataflowGraph const &);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Identifier type for a copy layer created during copy insertion. Carries no
# fields; it exists as a distinct guid type.
namespace = "FlexFlow"
name = "dynamic_copy_layer_guid_t"
type = "struct"
features = [
  "eq",
  "ord",
  "hash",
  "json",
  "fmt",
  "rapidcheck",
]

fields = []
Loading
Loading